diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index e30b26f0cb..5b71318794 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -20,7 +20,16 @@ from contextlib import contextmanager from datetime import datetime from functools import lru_cache from inspect import isfunction -from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints +from typing import ( + Any, + Callable, + Literal, + Optional, + Union, + get_args, + get_origin, + get_type_hints, +) from packaging import version @@ -75,7 +84,7 @@ class DocstringParsingException(Exception): pass -def _get_json_schema_type(param_type: str) -> dict[str, str]: +def _get_json_schema_type(param_type: type) -> dict[str, str]: type_mapping = { int: {"type": "integer"}, float: {"type": "number"}, @@ -119,6 +128,20 @@ def _parse_type_hint(hint: str) -> dict: return_dict["nullable"] = True return return_dict + elif origin is Literal and len(args) > 0: + LITERAL_TYPES = (int, float, str, bool, type(None)) + args_types = [] + for arg in args: + if type(arg) not in LITERAL_TYPES: + raise TypeHintParsingException("Only the valid python literals can be listed in typing.Literal.") + arg_type = _get_json_schema_type(type(arg)).get("type") + if arg_type is not None and arg_type not in args_types: + args_types.append(arg_type) + return { + "type": args_types.pop() if len(args_types) == 1 else list(args_types), + "enum": list(args), + } + elif origin is list: if not args: return {"type": "array"} diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 89a2dfa22e..52926649f8 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -13,7 +13,7 @@ # limitations under the License. import unittest -from typing import Optional, Union +from typing import Literal, Optional, Union from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema @@ -384,6 +384,49 @@ class JsonSchemaGeneratorTest(unittest.TestCase): self.assertEqual(schema["function"], expected_schema) + def test_literal(self): + def fn( + temperature_format: Literal["celsius", "fahrenheit"], + booleanish: Literal[True, False, 0, 1, "y", "n"] = False, + ): + """ + Test function + + Args: + temperature_format: The temperature format to use + booleanish: A value that can be regarded as boolean + + + Returns: + The temperature + """ + return -40.0 + + # Let's see if that gets correctly parsed as an enum + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "temperature_format": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The temperature format to use", + }, + "booleanish": { + "type": ["boolean", "integer", "string"], + "enum": [True, False, 0, 1, "y", "n"], + "description": "A value that can be regarded as boolean", + }, + }, + "required": ["temperature_format"], + }, + } + + self.assertEqual(schema["function"], expected_schema) + def test_multiline_docstring_with_types(self): def fn(x: int, y: int): """