Support typing.Literal as type of tool parameters or return value (#39633)
* support `typing.Literal` as type of tool parameters * validate the `args` of `typing.Literal` roughly * add test to get json schema for `typing.Literal` type hint * fix: add `"type"` attribute to the parsed result of `typing.Literal` * test: add argument `booleanish` to test multi-type literal * style: auto fixup
This commit is contained in:
@@ -20,7 +20,16 @@ from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from inspect import isfunction
|
||||
from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from packaging import version
|
||||
|
||||
@@ -75,7 +84,7 @@ class DocstringParsingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_json_schema_type(param_type: str) -> dict[str, str]:
|
||||
def _get_json_schema_type(param_type: type) -> dict[str, str]:
|
||||
type_mapping = {
|
||||
int: {"type": "integer"},
|
||||
float: {"type": "number"},
|
||||
@@ -119,6 +128,20 @@ def _parse_type_hint(hint: str) -> dict:
|
||||
return_dict["nullable"] = True
|
||||
return return_dict
|
||||
|
||||
elif origin is Literal and len(args) > 0:
|
||||
LITERAL_TYPES = (int, float, str, bool, type(None))
|
||||
args_types = []
|
||||
for arg in args:
|
||||
if type(arg) not in LITERAL_TYPES:
|
||||
raise TypeHintParsingException("Only the valid python literals can be listed in typing.Literal.")
|
||||
arg_type = _get_json_schema_type(type(arg)).get("type")
|
||||
if arg_type is not None and arg_type not in args_types:
|
||||
args_types.append(arg_type)
|
||||
return {
|
||||
"type": args_types.pop() if len(args_types) == 1 else list(args_types),
|
||||
"enum": list(args),
|
||||
}
|
||||
|
||||
elif origin is list:
|
||||
if not args:
|
||||
return {"type": "array"}
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from typing import Optional, Union
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
|
||||
|
||||
@@ -384,6 +384,49 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_literal(self):
|
||||
def fn(
|
||||
temperature_format: Literal["celsius", "fahrenheit"],
|
||||
booleanish: Literal[True, False, 0, 1, "y", "n"] = False,
|
||||
):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
temperature_format: The temperature format to use
|
||||
booleanish: A value that can be regarded as boolean
|
||||
|
||||
|
||||
Returns:
|
||||
The temperature
|
||||
"""
|
||||
return -40.0
|
||||
|
||||
# Let's see if that gets correctly parsed as an enum
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature_format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature format to use",
|
||||
},
|
||||
"booleanish": {
|
||||
"type": ["boolean", "integer", "string"],
|
||||
"enum": [True, False, 0, 1, "y", "n"],
|
||||
"description": "A value that can be regarded as boolean",
|
||||
},
|
||||
},
|
||||
"required": ["temperature_format"],
|
||||
},
|
||||
}
|
||||
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiline_docstring_with_types(self):
|
||||
def fn(x: int, y: int):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user