Support typing.Literal as type of tool parameters or return value (#39633)
* support `typing.Literal` as type of tool parameters * validate the `args` of `typing.Literal` roughly * add test to get json schema for `typing.Literal` type hint * fix: add `"type"` attribute to the parsed result of `typing.Literal` * test: add argument `booleanish` to test multi-type literal * style: auto fixup
This commit is contained in:
@@ -20,7 +20,16 @@ from contextlib import contextmanager
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
get_args,
|
||||||
|
get_origin,
|
||||||
|
get_type_hints,
|
||||||
|
)
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@@ -75,7 +84,7 @@ class DocstringParsingException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_json_schema_type(param_type: str) -> dict[str, str]:
|
def _get_json_schema_type(param_type: type) -> dict[str, str]:
|
||||||
type_mapping = {
|
type_mapping = {
|
||||||
int: {"type": "integer"},
|
int: {"type": "integer"},
|
||||||
float: {"type": "number"},
|
float: {"type": "number"},
|
||||||
@@ -119,6 +128,20 @@ def _parse_type_hint(hint: str) -> dict:
|
|||||||
return_dict["nullable"] = True
|
return_dict["nullable"] = True
|
||||||
return return_dict
|
return return_dict
|
||||||
|
|
||||||
|
elif origin is Literal and len(args) > 0:
|
||||||
|
LITERAL_TYPES = (int, float, str, bool, type(None))
|
||||||
|
args_types = []
|
||||||
|
for arg in args:
|
||||||
|
if type(arg) not in LITERAL_TYPES:
|
||||||
|
raise TypeHintParsingException("Only the valid python literals can be listed in typing.Literal.")
|
||||||
|
arg_type = _get_json_schema_type(type(arg)).get("type")
|
||||||
|
if arg_type is not None and arg_type not in args_types:
|
||||||
|
args_types.append(arg_type)
|
||||||
|
return {
|
||||||
|
"type": args_types.pop() if len(args_types) == 1 else list(args_types),
|
||||||
|
"enum": list(args),
|
||||||
|
}
|
||||||
|
|
||||||
elif origin is list:
|
elif origin is list:
|
||||||
if not args:
|
if not args:
|
||||||
return {"type": "array"}
|
return {"type": "array"}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
|
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
|
||||||
|
|
||||||
@@ -384,6 +384,49 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(schema["function"], expected_schema)
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_literal(self):
|
||||||
|
def fn(
|
||||||
|
temperature_format: Literal["celsius", "fahrenheit"],
|
||||||
|
booleanish: Literal[True, False, 0, 1, "y", "n"] = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temperature_format: The temperature format to use
|
||||||
|
booleanish: A value that can be regarded as boolean
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The temperature
|
||||||
|
"""
|
||||||
|
return -40.0
|
||||||
|
|
||||||
|
# Let's see if that gets correctly parsed as an enum
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"temperature_format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature format to use",
|
||||||
|
},
|
||||||
|
"booleanish": {
|
||||||
|
"type": ["boolean", "integer", "string"],
|
||||||
|
"enum": [True, False, 0, 1, "y", "n"],
|
||||||
|
"description": "A value that can be regarded as boolean",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["temperature_format"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
def test_multiline_docstring_with_types(self):
|
def test_multiline_docstring_with_types(self):
|
||||||
def fn(x: int, y: int):
|
def fn(x: int, y: int):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user