Support typing.Literal as type of tool parameters or return value (#39633)

* support `typing.Literal` as type of tool parameters

* validate the `args` of `typing.Literal` roughly

* add test to get json schema for `typing.Literal` type hint

* fix: add `"type"` attribute to the parsed result of `typing.Literal`

* test: add argument `booleanish` to test multi-type literal

* style: auto fixup
This commit is contained in:
Park Woorak
2025-07-26 02:51:28 +09:00
committed by GitHub
parent 300d42a43e
commit 3e4d584a5b
2 changed files with 69 additions and 3 deletions

View File

@@ -20,7 +20,16 @@ from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
from inspect import isfunction
from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
from typing import (
Any,
Callable,
Literal,
Optional,
Union,
get_args,
get_origin,
get_type_hints,
)
from packaging import version
@@ -75,7 +84,7 @@ class DocstringParsingException(Exception):
pass
def _get_json_schema_type(param_type: str) -> dict[str, str]:
def _get_json_schema_type(param_type: type) -> dict[str, str]:
type_mapping = {
int: {"type": "integer"},
float: {"type": "number"},
@@ -119,6 +128,20 @@ def _parse_type_hint(hint: str) -> dict:
return_dict["nullable"] = True
return return_dict
elif origin is Literal and len(args) > 0:
LITERAL_TYPES = (int, float, str, bool, type(None))
args_types = []
for arg in args:
if type(arg) not in LITERAL_TYPES:
raise TypeHintParsingException("Only the valid python literals can be listed in typing.Literal.")
arg_type = _get_json_schema_type(type(arg)).get("type")
if arg_type is not None and arg_type not in args_types:
args_types.append(arg_type)
return {
"type": args_types.pop() if len(args_types) == 1 else list(args_types),
"enum": list(args),
}
elif origin is list:
if not args:
return {"type": "array"}

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import unittest
from typing import Optional, Union
from typing import Literal, Optional, Union
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
@@ -384,6 +384,49 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
self.assertEqual(schema["function"], expected_schema)
def test_literal(self):
def fn(
temperature_format: Literal["celsius", "fahrenheit"],
booleanish: Literal[True, False, 0, 1, "y", "n"] = False,
):
"""
Test function
Args:
temperature_format: The temperature format to use
booleanish: A value that can be regarded as boolean
Returns:
The temperature
"""
return -40.0
# Let's see if that gets correctly parsed as an enum
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"temperature_format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature format to use",
},
"booleanish": {
"type": ["boolean", "integer", "string"],
"enum": [True, False, 0, 1, "y", "n"],
"description": "A value that can be regarded as boolean",
},
},
"required": ["temperature_format"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_multiline_docstring_with_types(self):
def fn(x: int, y: int):
"""