Chat Template support for function calling and RAG (#30621)
* First draft, still missing automatic function conversion * First draft of the automatic schema generator * Lots of small fixes * the walrus has betrayed me * please stop committing your debug breakpoints * Lots of cleanup and edge cases, looking better now * Comments and bugfixes for the type hint parser * More cleanup * Add tests, update schema generator * Update tests, proper handling of return values * Small docstring change * More doc updates * More doc updates * Add json_schema decorator * Clean up the TODOs and finish the docs * self.maxDiff = None to see the whole diff for the nested list test * add import for add_json_schema * Quick test fix * Fix something that was bugging me in the chat template docstring * Less "anyOf" when unnecessary * Support return types for the templates that need them * Proper return type tests * Switch to Google format docstrings * Update chat templating docs to match new format * Stop putting the return type in with the other parameters * Add Tuple support * No more decorator - we just do it implicitly! * Add enum support to get_json_schema * Update docstring * Add copyright header * Update src/transformers/tokenization_utils_base.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/chat_templating.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/chat_template_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/chat_template_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Add copyright header * make fixup * Fix indentation * Reformat chat_template_utils * Correct return value * Make regexes module-level * Support more complex, multi-line arg docstrings * Update error message for ... * Update ruff * Add document type validation * Refactor docs * Refactor docs * Refactor docs * Clean up Tuple error * Add an extra test for very complex defs and docstrings and clean everything up for it * Document enum block * Quick test fixes * Stop supporting type hints in docstring to fix bugs and simplify the regex * Update docs for the regex change * Clean up enum regex * Wrap functions in {"type": "function", "function": ...} * Update src/transformers/utils/chat_template_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * Temporary tool calling commit * Add type hints to chat template utils, partially update docs (incomplete!) * Code cleanup based on @molbap's suggestion * Add comments to explain regexes * Fix up type parsing for unions and lists * Add custom exception types and adjust tests to look for them * Update docs with a demo! * Docs cleanup * Pass content as string * Update tool call formatting * Update docs with new function format * Update docs * Update docs with a second tool to show the model choosing correctly --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
@@ -233,6 +233,332 @@ The sun.</s>
|
||||
|
||||
From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column.
|
||||
|
||||
## Advanced: Extra inputs to chat templates
|
||||
|
||||
The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword
|
||||
argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use
|
||||
chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass
|
||||
strings, lists, dicts or whatever else you want.
|
||||
|
||||
That said, there are some common use-cases for these extra arguments,
|
||||
such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases,
|
||||
we have some opinionated recommendations about what the names and formats of these arguments should be, which are
|
||||
described in the sections below. We encourage model authors to make their chat templates compatible with this format,
|
||||
to make it easy to transfer tool-calling code between models.
|
||||
|
||||
## Advanced: Tool use / function calling
|
||||
|
||||
"Tool use" LLMs can choose to call functions as external tools before generating an answer. When passing tools
|
||||
to a tool-use model, you can simply pass a list of functions to the `tools` argument:
|
||||
|
||||
```python
|
||||
import datetime
|
||||
|
||||
def current_time():
|
||||
"""Get the current local time as a string."""
|
||||
return str(datetime.now())
|
||||
|
||||
def multiply(a: float, b: float):
|
||||
"""
|
||||
A function that multiplies two numbers
|
||||
|
||||
Args:
|
||||
a: The first number to multiply
|
||||
b: The second number to multiply
|
||||
"""
|
||||
return a * b
|
||||
|
||||
tools = [current_time, multiply]
|
||||
|
||||
model_input = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=tools
|
||||
)
|
||||
```
|
||||
|
||||
In order for this to work correctly, you should write your functions in the format above, so that they can be parsed
|
||||
correctly as tools. Specifically, you should follow these rules:
|
||||
|
||||
- The function should have a descriptive name
|
||||
- Every argument must have a type hint
|
||||
- The function must have a docstring in the standard Google style (in other words, an initial function description
|
||||
followed by an `Args:` block that describes the arguments, unless the function does not have any arguments.
|
||||
- Do not include types in the `Args:` block. In other words, write `a: The first number to multiply`, not
|
||||
`a (int): The first number to multiply`. Type hints should go in the function header instead.
|
||||
- The function can have a return type and a `Returns:` block in the docstring. However, these are optional
|
||||
because most tool-use models ignore them.
|
||||
|
||||
### Passing tool results to the model
|
||||
|
||||
The sample code above is enough to list the available tools for your model, but what happens if it wants to actually use
|
||||
one? If that happens, you should:
|
||||
|
||||
1. Parse the model's output to get the tool name(s) and arguments.
|
||||
2. Add the model's tool call(s) to the conversation.
|
||||
3. Call the corresponding function(s) with those arguments.
|
||||
4. Add the result(s) to the conversation
|
||||
|
||||
### A complete tool use example
|
||||
|
||||
Let's walk through a tool use example, step by step. For this example, we will use an 8B `Hermes-2-Pro` model,
|
||||
as it is one of the highest-performing tool-use models in its size category at the time of writing. If you have the
|
||||
memory, you can consider using a larger model instead like [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
|
||||
or [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), both of which also support tool use
|
||||
and offer even stronger performance.
|
||||
|
||||
First, let's load our model and tokenizer:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
checkpoint = "NousResearch/Hermes-2-Pro-Llama-3-8B"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision="pr/13")
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")
|
||||
```
|
||||
|
||||
Next, let's define a list of tools:
|
||||
|
||||
```python
|
||||
def get_current_temperature(location: str, unit: str) -> float:
|
||||
"""
|
||||
Get the current temperature at a location.
|
||||
|
||||
Args:
|
||||
location: The location to get the temperature for, in the format "City, Country"
|
||||
unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
|
||||
Returns:
|
||||
The current temperature at the specified location in the specified units, as a float.
|
||||
"""
|
||||
return 22. # A real function should probably actually get the temperature!
|
||||
|
||||
def get_current_wind_speed(location: str) -> float:
|
||||
"""
|
||||
Get the current wind speed in km/h at a given location.
|
||||
|
||||
Args:
|
||||
location: The location to get the temperature for, in the format "City, Country"
|
||||
Returns:
|
||||
The current wind speed at the given location in km/h, as a float.
|
||||
"""
|
||||
return 6. # A real function should probably actually get the wind speed!
|
||||
|
||||
tools = [get_current_temperature, get_current_wind_speed]
|
||||
```
|
||||
|
||||
Now, let's set up a conversation for our bot:
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location."},
|
||||
{"role": "user", "content": "Hey, what's the temperature in Paris right now?"}
|
||||
]
|
||||
```
|
||||
|
||||
Now, let's apply the chat template and generate a response:
|
||||
|
||||
```python
|
||||
inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt")
|
||||
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
||||
out = model.generate(**inputs, max_new_tokens=128)
|
||||
print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
|
||||
```
|
||||
|
||||
And we get:
|
||||
|
||||
```text
|
||||
<tool_call>
|
||||
{"arguments": {"location": "Paris, France", "unit": "celsius"}, "name": "get_current_temperature"}
|
||||
</tool_call><|im_end|>
|
||||
```
|
||||
|
||||
The model has called the function with valid arguments, in the format requested by the function docstring. It has
|
||||
inferred that we're most likely referring to the Paris in France, and it remembered that, as the home of SI units,
|
||||
the temperature in France should certainly be displayed in Celsius.
|
||||
|
||||
Let's append the model's tool call to the conversation. Note that we generate a random `tool_call_id` here. These IDs
|
||||
are not used by all models, but they allow models to issue multiple tool calls at once and keep track of which response
|
||||
corresponds to which call. You can generate them any way you like, but they should be unique within each chat.
|
||||
|
||||
```python
|
||||
tool_call_id = "vAHdf3" # Random ID, should be unique for each tool call
|
||||
tool_call = {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}
|
||||
messages.append({"role": "assistant", "tool_calls": [{"id": tool_call_id, "type": "function", "function": tool_call}]})
|
||||
```
|
||||
|
||||
|
||||
Now that we've added the tool call to the conversation, we can call the function and append the result to the
|
||||
conversation. Since we're just using a dummy function for this example that always returns 22.0, we can just append
|
||||
that result directly. Again, note the `tool_call_id` - this should match the ID used in the tool call above.
|
||||
|
||||
```python
|
||||
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": "get_current_temperature", "content": "22.0"})
|
||||
```
|
||||
|
||||
Finally, let's let the assistant read the function outputs and continue chatting with the user:
|
||||
|
||||
```python
|
||||
inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt")
|
||||
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
||||
out = model.generate(**inputs, max_new_tokens=128)
|
||||
print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
|
||||
```
|
||||
|
||||
And we get:
|
||||
|
||||
```text
|
||||
The current temperature in Paris, France is 22.0 ° Celsius.<|im_end|>
|
||||
```
|
||||
|
||||
Although this was a simple demo with dummy tools and a single call, the same technique works with
|
||||
multiple real tools and longer conversations. This can be a powerful way to extend the capabilities of conversational
|
||||
agents with real-time information, computational tools like calculators, or access to large databases.
|
||||
|
||||
<Tip>
|
||||
Not all of the tool-calling features shown above are used by all models. Some use tool call IDs, others simply use the function name and
|
||||
match tool calls to results using the ordering, and there are several models that use neither and only issue one tool
|
||||
call at a time to avoid confusion. If you want your code to be compatible across as many models as possible, we
|
||||
recommend structuring your tools calls like we've shown here, and returning tool results in the order that
|
||||
they were issued by the model. The chat templates on each model should handle the rest.
|
||||
</Tip>
|
||||
|
||||
### Understanding tool schemas
|
||||
|
||||
Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
|
||||
[JSON schema](https://json-schema.org/learn/getting-started-step-by-step). These schemas
|
||||
are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they
|
||||
never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they
|
||||
need to pass to them - they care about what the tools do and how to use them, not how they work! It is up to you
|
||||
to read their outputs, detect if they have requested to use a tool, pass their arguments to the tool function, and
|
||||
return the response in the chat.
|
||||
|
||||
Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions
|
||||
follow the specification above, but if you encounter problems, or you simply want more control over the conversion,
|
||||
you can handle the conversion manually. Here is an example of a manual schema conversion.
|
||||
|
||||
```python
|
||||
from transformers.utils import get_json_schema
|
||||
|
||||
def multiply(a: float, b: float):
|
||||
"""
|
||||
A function that multiplies two numbers
|
||||
|
||||
Args:
|
||||
a: The first number to multiply
|
||||
b: The second number to multiply
|
||||
"""
|
||||
return a * b
|
||||
|
||||
schema = get_json_schema(multiply)
|
||||
print(schema)
|
||||
```
|
||||
|
||||
This will yield:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiply",
|
||||
"description": "A function that multiplies two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "number",
|
||||
"description": "The first number to multiply"
|
||||
},
|
||||
"b": {
|
||||
"type": "number",
|
||||
"description": "The second number to multiply"
|
||||
}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
|
||||
all. JSON schemas can be passed directly to the `tools` argument of
|
||||
`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful,
|
||||
though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
|
||||
recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
|
||||
to a minimum.
|
||||
|
||||
Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`:
|
||||
|
||||
```python
|
||||
# A simple function that takes no arguments
|
||||
current_time = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "current_time",
|
||||
"description": "Get the current local time as a string.",
|
||||
"parameters": {
|
||||
'type': 'object',
|
||||
'properties': {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# A more complete function that takes two numerical arguments
|
||||
multiply = {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'multiply',
|
||||
'description': 'A function that multiplies two numbers',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'a': {
|
||||
'type': 'number',
|
||||
'description': 'The first number to multiply'
|
||||
},
|
||||
'b': {
|
||||
'type': 'number', 'description': 'The second number to multiply'
|
||||
}
|
||||
},
|
||||
'required': ['a', 'b']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
model_input = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools = [current_time, multiply]
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced: Retrieval-augmented generation
|
||||
|
||||
"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding
|
||||
to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
|
||||
recommendation for RAG models is that their template
|
||||
should accept a `documents` argument. This should be a list of documents, where each "document"
|
||||
is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler
|
||||
than the JSON schemas used for tools, no helper functions are necessary.
|
||||
|
||||
Here's an example of a RAG template in action:
|
||||
|
||||
```python
|
||||
document1 = {
|
||||
"title": "The Moon: Our Age-Old Foe",
|
||||
"contents": "Man has always dreamed of destroying the moon. In this essay, I shall..."
|
||||
}
|
||||
|
||||
document2 = {
|
||||
"title": "The Sun: Our Age-Old Friend",
|
||||
"contents": "Although often underappreciated, the sun provides several notable benefits..."
|
||||
}
|
||||
|
||||
model_input = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
documents=[document1, document2]
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced: How do chat templates work?
|
||||
|
||||
The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the
|
||||
|
||||
@@ -28,6 +28,7 @@ from collections.abc import Mapping, Sized
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from inspect import isfunction
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -47,6 +48,7 @@ from .utils import (
|
||||
copy_func,
|
||||
download_url,
|
||||
extract_commit_hash,
|
||||
get_json_schema,
|
||||
is_flax_available,
|
||||
is_jax_tensor,
|
||||
is_mlx_available,
|
||||
@@ -1683,6 +1685,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
def apply_chat_template(
|
||||
self,
|
||||
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
||||
tools: Optional[List[Dict]] = None,
|
||||
documents: Optional[List[Dict[str, str]]] = None,
|
||||
chat_template: Optional[str] = None,
|
||||
add_generation_prompt: bool = False,
|
||||
tokenize: bool = True,
|
||||
@@ -1703,8 +1707,21 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
Args:
|
||||
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
|
||||
with "role" and "content" keys, representing the chat history so far.
|
||||
chat_template (str, *optional*): A Jinja template to use for this conversion. If
|
||||
this is not passed, the model's default chat template will be used instead.
|
||||
tools (`List[Dict]`, *optional*):
|
||||
A list of tools (callable functions) that will be accessible to the model. If the template does not
|
||||
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
|
||||
giving the name, description and argument types for the tool. See our
|
||||
[chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
|
||||
for more information.
|
||||
documents (`List[Dict[str, str]]`, *optional*):
|
||||
A list of dicts representing documents that will be accessible to the model if it is performing RAG
|
||||
(retrieval-augmented generation). If the template does not support RAG, this argument will have no
|
||||
effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
|
||||
see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
|
||||
for examples of passing documents with chat templates.
|
||||
chat_template (`str`, *optional*):
|
||||
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
|
||||
argument, as the model's template will be used by default.
|
||||
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
|
||||
the start of an assistant message. This is useful when you want to generate a response from the model.
|
||||
Note that this argument will be passed to the chat template, and so it must be supported in the
|
||||
@@ -1802,6 +1819,27 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
conversations = [conversation]
|
||||
is_batched = False
|
||||
|
||||
# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
|
||||
if tools is not None:
|
||||
tool_schemas = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
tool_schemas.append(tool)
|
||||
elif isfunction(tool):
|
||||
tool_schemas.append(get_json_schema(tool))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Tools should either be a JSON schema, or a callable function with type hints "
|
||||
"and a docstring suitable for auto-conversion to a schema."
|
||||
)
|
||||
else:
|
||||
tool_schemas = None
|
||||
|
||||
if documents is not None:
|
||||
for document in documents:
|
||||
if not isinstance(document, dict):
|
||||
raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
|
||||
|
||||
rendered = []
|
||||
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
|
||||
for chat in conversations:
|
||||
@@ -1809,7 +1847,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
# Indicates it's a Conversation object
|
||||
chat = chat.messages
|
||||
rendered_chat = compiled_template.render(
|
||||
messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
|
||||
messages=chat,
|
||||
tools=tool_schemas,
|
||||
documents=documents,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**template_kwargs,
|
||||
)
|
||||
rendered.append(rendered_chat)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from packaging import version
|
||||
|
||||
from .. import __version__
|
||||
from .backbone_utils import BackboneConfigMixin, BackboneMixin
|
||||
from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema
|
||||
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
|
||||
from .doc import (
|
||||
add_code_sample_docstrings,
|
||||
|
||||
316
src/transformers/utils/chat_template_utils.py
Normal file
316
src/transformers/utils/chat_template_utils.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
|
||||
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
|
||||
# Extracts the initial segment of the docstring, containing the function description
|
||||
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
|
||||
# Extracts the Args: block from the docstring
|
||||
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
|
||||
# Splits the Args: block into individual arguments
|
||||
args_split_re = re.compile(
|
||||
r"""
|
||||
(?:^|\n) # Match the start of the args block, or a newline
|
||||
\s*(\w+):\s* # Capture the argument name and strip spacing
|
||||
(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
|
||||
(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
|
||||
""",
|
||||
re.DOTALL | re.VERBOSE,
|
||||
)
|
||||
# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
|
||||
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
|
||||
|
||||
|
||||
class TypeHintParsingException(Exception):
|
||||
"""Exception raised for errors in parsing type hints to generate JSON schemas"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DocstringParsingException(Exception):
|
||||
"""Exception raised for errors in parsing docstrings to generate JSON schemas"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
||||
type_mapping = {
|
||||
int: {"type": "integer"},
|
||||
float: {"type": "number"},
|
||||
str: {"type": "string"},
|
||||
bool: {"type": "boolean"},
|
||||
Any: {},
|
||||
}
|
||||
return type_mapping.get(param_type, {"type": "object"})
|
||||
|
||||
|
||||
def _parse_type_hint(hint: str) -> Dict:
|
||||
origin = get_origin(hint)
|
||||
args = get_args(hint)
|
||||
|
||||
if origin is None:
|
||||
try:
|
||||
return _get_json_schema_type(hint)
|
||||
except KeyError:
|
||||
raise TypeHintParsingException(
|
||||
"Couldn't parse this type hint, likely due to a custom class or object: ", hint
|
||||
)
|
||||
|
||||
elif origin is Union:
|
||||
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
|
||||
subtypes = [_parse_type_hint(t) for t in args if t != type(None)]
|
||||
if len(subtypes) == 1:
|
||||
# A single non-null type can be expressed directly
|
||||
return_dict = subtypes[0]
|
||||
elif all(isinstance(subtype["type"], str) for subtype in subtypes):
|
||||
# A union of basic types can be expressed as a list in the schema
|
||||
return_dict = {"type": [subtype["type"] for subtype in subtypes]}
|
||||
else:
|
||||
# A union of more complex types requires "anyOf"
|
||||
return_dict = {"anyOf": subtypes}
|
||||
if type(None) in args:
|
||||
return_dict["nullable"] = True
|
||||
return return_dict
|
||||
|
||||
elif origin is list:
|
||||
if not args:
|
||||
return {"type": "array"}
|
||||
else:
|
||||
# Lists can only have a single type argument, so recurse into it
|
||||
return {"type": "array", "items": _parse_type_hint(args[0])}
|
||||
|
||||
elif origin is tuple:
|
||||
if not args:
|
||||
return {"type": "array"}
|
||||
if len(args) == 1:
|
||||
raise TypeHintParsingException(
|
||||
f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
|
||||
"we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
|
||||
"more than one element, we recommend "
|
||||
"using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
|
||||
"pass the element directly."
|
||||
)
|
||||
if ... in args:
|
||||
raise TypeHintParsingException(
|
||||
"Conversion of '...' is not supported in Tuple type hints. "
|
||||
"Use List[] types for variable-length"
|
||||
" inputs instead."
|
||||
)
|
||||
return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
|
||||
|
||||
elif origin is dict:
|
||||
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
|
||||
# However, we can specify the type of the dict values with "additionalProperties"
|
||||
out = {"type": "object"}
|
||||
if len(args) == 2:
|
||||
out["additionalProperties"] = _parse_type_hint(args[1])
|
||||
return out
|
||||
|
||||
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
|
||||
|
||||
|
||||
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
||||
type_hints = get_type_hints(func)
|
||||
signature = inspect.signature(func)
|
||||
required = []
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param.annotation == inspect.Parameter.empty:
|
||||
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
|
||||
if param.default == inspect.Parameter.empty:
|
||||
required.append(param_name)
|
||||
|
||||
properties = {}
|
||||
for param_name, param_type in type_hints.items():
|
||||
properties[param_name] = _parse_type_hint(param_type)
|
||||
|
||||
schema = {"type": "object", "properties": properties}
|
||||
if required:
|
||||
schema["required"] = required
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
|
||||
"""
|
||||
Parses a Google-style docstring to extract the function description,
|
||||
argument descriptions, and return description.
|
||||
|
||||
Args:
|
||||
docstring (str): The docstring to parse.
|
||||
|
||||
Returns:
|
||||
The function description, arguments, and return description.
|
||||
"""
|
||||
|
||||
# Extract the sections
|
||||
description_match = description_re.search(docstring)
|
||||
args_match = args_re.search(docstring)
|
||||
returns_match = returns_re.search(docstring)
|
||||
|
||||
# Clean and store the sections
|
||||
description = description_match.group(1).strip() if description_match else None
|
||||
docstring_args = args_match.group(1).strip() if args_match else None
|
||||
returns = returns_match.group(1).strip() if returns_match else None
|
||||
|
||||
# Parsing the arguments into a dictionary
|
||||
if docstring_args is not None:
|
||||
docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
|
||||
matches = args_split_re.findall(docstring_args)
|
||||
args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
|
||||
else:
|
||||
args_dict = {}
|
||||
|
||||
return description, args_dict, returns
|
||||
|
||||
|
||||
def get_json_schema(func: Callable) -> Dict:
|
||||
"""
|
||||
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
|
||||
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
|
||||
the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
|
||||
that the function has a docstring, and that each argument has a description in the docstring, in the standard
|
||||
Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
|
||||
|
||||
Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
|
||||
optional because most chat templates ignore the return value of the function.
|
||||
|
||||
Args:
|
||||
func: The function to generate a JSON schema for.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the JSON schema for the function.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> def multiply(x: float, y: float):
|
||||
>>> '''
|
||||
>>> A function that multiplies two numbers
|
||||
>>>
|
||||
>>> Args:
|
||||
>>> x: The first number to multiply
|
||||
>>> y: The second number to multiply
|
||||
>>> '''
|
||||
>>> return x * y
|
||||
>>>
|
||||
>>> print(get_json_schema(multiply))
|
||||
{
|
||||
"name": "multiply",
|
||||
"description": "A function that multiplies two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "number", "description": "The first number to multiply"},
|
||||
"y": {"type": "number", "description": "The second number to multiply"}
|
||||
},
|
||||
"required": ["x", "y"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The general use for these schemas is that they are used to generate tool descriptions for chat templates that
|
||||
support them, like so:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer
|
||||
>>> from transformers.utils import get_json_schema
|
||||
>>>
|
||||
>>> def multiply(x: float, y: float):
|
||||
>>> '''
|
||||
>>> A function that multiplies two numbers
|
||||
>>>
|
||||
>>> Args:
|
||||
>>> x: The first number to multiply
|
||||
>>> y: The second number to multiply
|
||||
>>> return x * y
|
||||
>>> '''
|
||||
>>>
|
||||
>>> multiply_schema = get_json_schema(multiply)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
>>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
|
||||
>>> formatted_chat = tokenizer.apply_chat_template(
|
||||
>>> messages,
|
||||
>>> tools=[multiply_schema],
|
||||
>>> chat_template="tool_use",
|
||||
>>> return_dict=True,
|
||||
>>> return_tensors="pt",
|
||||
>>> add_generation_prompt=True
|
||||
>>> )
|
||||
>>> # The formatted chat can now be passed to model.generate()
|
||||
```
|
||||
|
||||
Each argument description can also have an optional `(choices: ...)` block at the end, such as
|
||||
`(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
|
||||
only be parsed correctly if it is at the end of the line:
|
||||
|
||||
```python
|
||||
>>> def drink_beverage(beverage: str):
|
||||
>>> '''
|
||||
>>> A function that drinks a beverage
|
||||
>>>
|
||||
>>> Args:
|
||||
>>> beverage: The beverage to drink (choices: ["tea", "coffee"])
|
||||
>>> '''
|
||||
>>> pass
|
||||
>>>
|
||||
>>> print(get_json_schema(drink_beverage))
|
||||
```
|
||||
{
|
||||
'name': 'drink_beverage',
|
||||
'description': 'A function that drinks a beverage',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'beverage': {
|
||||
'type': 'string',
|
||||
'enum': ['tea', 'coffee'],
|
||||
'description': 'The beverage to drink'
|
||||
}
|
||||
},
|
||||
'required': ['beverage']
|
||||
}
|
||||
}
|
||||
"""
|
||||
doc = inspect.getdoc(func)
|
||||
if not doc:
|
||||
raise DocstringParsingException(
|
||||
f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
|
||||
)
|
||||
doc = doc.strip()
|
||||
main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)
|
||||
|
||||
json_schema = _convert_type_hints_to_json_schema(func)
|
||||
if (return_dict := json_schema["properties"].pop("return", None)) is not None:
|
||||
if return_doc is not None: # We allow a missing return docstring since most templates ignore it
|
||||
return_dict["description"] = return_doc
|
||||
for arg, schema in json_schema["properties"].items():
|
||||
if arg not in param_descriptions:
|
||||
raise DocstringParsingException(
|
||||
f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
|
||||
)
|
||||
desc = param_descriptions[arg]
|
||||
enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
|
||||
if enum_choices:
|
||||
schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
|
||||
desc = enum_choices.string[: enum_choices.start()].strip()
|
||||
schema["description"] = desc
|
||||
|
||||
output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
|
||||
if return_dict is not None:
|
||||
output["return"] = return_dict
|
||||
return {"type": "function", "function": output}
|
||||
476
tests/utils/test_chat_template_utils.py
Normal file
476
tests/utils/test_chat_template_utils.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
|
||||
|
||||
|
||||
class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
def test_simple_function(self):
|
||||
def fn(x: int):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_no_arguments(self):
|
||||
def fn():
|
||||
"""
|
||||
Test function
|
||||
"""
|
||||
return True
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_union(self):
|
||||
def fn(x: Union[int, float]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": ["integer", "number"], "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_optional(self):
|
||||
def fn(x: Optional[int]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input", "nullable": True}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_default_arg(self):
|
||||
def fn(x: int = 42):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_nested_list(self):
|
||||
def fn(x: List[List[Union[str, int]]]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "array",
|
||||
"items": {"type": "array", "items": {"type": ["string", "integer"]}},
|
||||
"description": "The input",
|
||||
}
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiple_arguments(self):
|
||||
def fn(x: int, y: str):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
y: Also the input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "integer", "description": "The input"},
|
||||
"y": {"type": "string", "description": "Also the input"},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiple_complex_arguments(self):
|
||||
def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
y: Also the input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"},
|
||||
"y": {
|
||||
"type": ["integer", "string"],
|
||||
"nullable": True,
|
||||
"description": "Also the input",
|
||||
},
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_missing_docstring(self):
|
||||
def fn(x: int):
|
||||
return x
|
||||
|
||||
with self.assertRaises(DocstringParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_missing_param_docstring(self):
|
||||
def fn(x: int):
|
||||
"""
|
||||
Test function
|
||||
"""
|
||||
return x
|
||||
|
||||
with self.assertRaises(DocstringParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_missing_type_hint(self):
|
||||
def fn(x):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
with self.assertRaises(TypeHintParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_return_value(self):
|
||||
def fn(x: int) -> int:
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
"return": {"type": "integer"},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_return_value_docstring(self):
|
||||
def fn(x: int) -> int:
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
"return": {"type": "integer", "description": "The output"},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_tuple(self):
|
||||
def fn(x: Tuple[int, str]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "array",
|
||||
"prefixItems": [{"type": "integer"}, {"type": "string"}],
|
||||
"description": "The input",
|
||||
}
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_single_element_tuple_fails(self):
|
||||
def fn(x: Tuple[int]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
# Single-element tuples should just be the type itself, or List[type] for variable-length inputs
|
||||
with self.assertRaises(TypeHintParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_ellipsis_type_fails(self):
|
||||
def fn(x: Tuple[int, ...]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
# Variable length inputs should be specified with List[type], not Tuple[type, ...]
|
||||
with self.assertRaises(TypeHintParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_enum_extraction(self):
|
||||
def fn(temperature_format: str):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"])
|
||||
|
||||
|
||||
Returns:
|
||||
The temperature
|
||||
"""
|
||||
return -40.0
|
||||
|
||||
# Let's see if that gets correctly parsed as an enum
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature_format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature format to use",
|
||||
}
|
||||
},
|
||||
"required": ["temperature_format"],
|
||||
},
|
||||
}
|
||||
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiline_docstring_with_types(self):
|
||||
def fn(x: int, y: int):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The first input
|
||||
|
||||
y: The second input. This is a longer description
|
||||
that spans multiple lines with indentation and stuff.
|
||||
|
||||
Returns:
|
||||
God knows what
|
||||
"""
|
||||
pass
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "integer", "description": "The first input"},
|
||||
"y": {
|
||||
"type": "integer",
|
||||
"description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.",
|
||||
},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
},
|
||||
}
|
||||
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_everything_all_at_once(self):
|
||||
def fn(
|
||||
x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello")
|
||||
) -> Tuple[int, str]:
|
||||
"""
|
||||
Test function with multiple args, and docstring args that we have to strip out.
|
||||
|
||||
Args:
|
||||
x: The first input. It's got a big multiline
|
||||
description and also contains
|
||||
(choices: ["a", "b", "c"])
|
||||
|
||||
y: The second input. It's a big list with a single-line description.
|
||||
|
||||
z: The third input. It's some kind of tuple with a default arg.
|
||||
|
||||
Returns:
|
||||
The output. The return description is also a big multiline
|
||||
description that spans multiple lines.
|
||||
"""
|
||||
pass
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function with multiple args, and docstring args that we have to strip out.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b", "c"],
|
||||
"description": "The first input. It's got a big multiline description and also contains",
|
||||
},
|
||||
"y": {
|
||||
"type": "array",
|
||||
"items": {"type": ["string", "integer"]},
|
||||
"nullable": True,
|
||||
"description": "The second input. It's a big list with a single-line description.",
|
||||
},
|
||||
"z": {
|
||||
"type": "array",
|
||||
"prefixItems": [{"type": ["string", "integer"]}, {"type": "string"}],
|
||||
"description": "The third input. It's some kind of tuple with a default arg.",
|
||||
},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
},
|
||||
"return": {
|
||||
"type": "array",
|
||||
"prefixItems": [{"type": "integer"}, {"type": "string"}],
|
||||
"description": "The output. The return description is also a big multiline\n description that spans multiple lines.",
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
Reference in New Issue
Block a user