Chat Template support for function calling and RAG (#30621)
* First draft, still missing automatic function conversion * First draft of the automatic schema generator * Lots of small fixes * the walrus has betrayed me * please stop committing your debug breakpoints * Lots of cleanup and edge cases, looking better now * Comments and bugfixes for the type hint parser * More cleanup * Add tests, update schema generator * Update tests, proper handling of return values * Small docstring change * More doc updates * More doc updates * Add json_schema decorator * Clean up the TODOs and finish the docs * self.maxDiff = None to see the whole diff for the nested list test * add import for add_json_schema * Quick test fix * Fix something that was bugging me in the chat template docstring * Less "anyOf" when unnecessary * Support return types for the templates that need them * Proper return type tests * Switch to Google format docstrings * Update chat templating docs to match new format * Stop putting the return type in with the other parameters * Add Tuple support * No more decorator - we just do it implicitly! * Add enum support to get_json_schema * Update docstring * Add copyright header * Update src/transformers/tokenization_utils_base.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/chat_templating.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/chat_template_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/chat_template_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Add copyright header * make fixup * Fix indentation * Reformat chat_template_utils * Correct return value * Make regexes module-level * Support more complex, multi-line arg docstrings * Update error message for ... * Update ruff * Add document type validation * Refactor docs * Refactor docs * Refactor docs * Clean up Tuple error * Add an extra test for very complex defs and docstrings and clean everything up for it * Document enum block * Quick test fixes * Stop supporting type hints in docstring to fix bugs and simplify the regex * Update docs for the regex change * Clean up enum regex * Wrap functions in {"type": "function", "function": ...} * Update src/transformers/utils/chat_template_utils.py Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * Temporary tool calling commit * Add type hints to chat template utils, partially update docs (incomplete!) * Code cleanup based on @molbap's suggestion * Add comments to explain regexes * Fix up type parsing for unions and lists * Add custom exception types and adjust tests to look for them * Update docs with a demo! * Docs cleanup * Pass content as string * Update tool call formatting * Update docs with new function format * Update docs * Update docs with a second tool to show the model choosing correctly --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
@@ -233,6 +233,332 @@ The sun.</s>
|
|||||||
|
|
||||||
From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column.
|
From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column.
|
||||||
|
|
||||||
|
## Advanced: Extra inputs to chat templates
|
||||||
|
|
||||||
|
The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword
|
||||||
|
argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use
|
||||||
|
chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass
|
||||||
|
strings, lists, dicts or whatever else you want.
|
||||||
|
|
||||||
|
That said, there are some common use-cases for these extra arguments,
|
||||||
|
such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases,
|
||||||
|
we have some opinionated recommendations about what the names and formats of these arguments should be, which are
|
||||||
|
described in the sections below. We encourage model authors to make their chat templates compatible with this format,
|
||||||
|
to make it easy to transfer tool-calling code between models.
|
||||||
|
|
||||||
|
## Advanced: Tool use / function calling
|
||||||
|
|
||||||
|
"Tool use" LLMs can choose to call functions as external tools before generating an answer. When passing tools
|
||||||
|
to a tool-use model, you can simply pass a list of functions to the `tools` argument:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
def current_time():
|
||||||
|
"""Get the current local time as a string."""
|
||||||
|
return str(datetime.now())
|
||||||
|
|
||||||
|
def multiply(a: float, b: float):
|
||||||
|
"""
|
||||||
|
A function that multiplies two numbers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: The first number to multiply
|
||||||
|
b: The second number to multiply
|
||||||
|
"""
|
||||||
|
return a * b
|
||||||
|
|
||||||
|
tools = [current_time, multiply]
|
||||||
|
|
||||||
|
model_input = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tools=tools
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
In order for this to work correctly, you should write your functions in the format above, so that they can be parsed
|
||||||
|
correctly as tools. Specifically, you should follow these rules:
|
||||||
|
|
||||||
|
- The function should have a descriptive name
|
||||||
|
- Every argument must have a type hint
|
||||||
|
- The function must have a docstring in the standard Google style (in other words, an initial function description
|
||||||
|
followed by an `Args:` block that describes the arguments, unless the function does not have any arguments.
|
||||||
|
- Do not include types in the `Args:` block. In other words, write `a: The first number to multiply`, not
|
||||||
|
`a (int): The first number to multiply`. Type hints should go in the function header instead.
|
||||||
|
- The function can have a return type and a `Returns:` block in the docstring. However, these are optional
|
||||||
|
because most tool-use models ignore them.
|
||||||
|
|
||||||
|
### Passing tool results to the model
|
||||||
|
|
||||||
|
The sample code above is enough to list the available tools for your model, but what happens if it wants to actually use
|
||||||
|
one? If that happens, you should:
|
||||||
|
|
||||||
|
1. Parse the model's output to get the tool name(s) and arguments.
|
||||||
|
2. Add the model's tool call(s) to the conversation.
|
||||||
|
3. Call the corresponding function(s) with those arguments.
|
||||||
|
4. Add the result(s) to the conversation
|
||||||
|
|
||||||
|
### A complete tool use example
|
||||||
|
|
||||||
|
Let's walk through a tool use example, step by step. For this example, we will use an 8B `Hermes-2-Pro` model,
|
||||||
|
as it is one of the highest-performing tool-use models in its size category at the time of writing. If you have the
|
||||||
|
memory, you can consider using a larger model instead like [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
|
||||||
|
or [Mixtral-8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1), both of which also support tool use
|
||||||
|
and offer even stronger performance.
|
||||||
|
|
||||||
|
First, let's load our model and tokenizer:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
checkpoint = "NousResearch/Hermes-2-Pro-Llama-3-8B"
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint, revision="pr/13")
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")
|
||||||
|
```
|
||||||
|
|
||||||
|
Next, let's define a list of tools:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_current_temperature(location: str, unit: str) -> float:
|
||||||
|
"""
|
||||||
|
Get the current temperature at a location.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location: The location to get the temperature for, in the format "City, Country"
|
||||||
|
unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
|
||||||
|
Returns:
|
||||||
|
The current temperature at the specified location in the specified units, as a float.
|
||||||
|
"""
|
||||||
|
return 22. # A real function should probably actually get the temperature!
|
||||||
|
|
||||||
|
def get_current_wind_speed(location: str) -> float:
|
||||||
|
"""
|
||||||
|
Get the current wind speed in km/h at a given location.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location: The location to get the temperature for, in the format "City, Country"
|
||||||
|
Returns:
|
||||||
|
The current wind speed at the given location in km/h, as a float.
|
||||||
|
"""
|
||||||
|
return 6. # A real function should probably actually get the wind speed!
|
||||||
|
|
||||||
|
tools = [get_current_temperature, get_current_wind_speed]
|
||||||
|
```
|
||||||
|
|
||||||
|
Now, let's set up a conversation for our bot:
|
||||||
|
|
||||||
|
```python
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location."},
|
||||||
|
{"role": "user", "content": "Hey, what's the temperature in Paris right now?"}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Now, let's apply the chat template and generate a response:
|
||||||
|
|
||||||
|
```python
|
||||||
|
inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt")
|
||||||
|
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
||||||
|
out = model.generate(**inputs, max_new_tokens=128)
|
||||||
|
print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
|
||||||
|
```
|
||||||
|
|
||||||
|
And we get:
|
||||||
|
|
||||||
|
```text
|
||||||
|
<tool_call>
|
||||||
|
{"arguments": {"location": "Paris, France", "unit": "celsius"}, "name": "get_current_temperature"}
|
||||||
|
</tool_call><|im_end|>
|
||||||
|
```
|
||||||
|
|
||||||
|
The model has called the function with valid arguments, in the format requested by the function docstring. It has
|
||||||
|
inferred that we're most likely referring to the Paris in France, and it remembered that, as the home of SI units,
|
||||||
|
the temperature in France should certainly be displayed in Celsius.
|
||||||
|
|
||||||
|
Let's append the model's tool call to the conversation. Note that we generate a random `tool_call_id` here. These IDs
|
||||||
|
are not used by all models, but they allow models to issue multiple tool calls at once and keep track of which response
|
||||||
|
corresponds to which call. You can generate them any way you like, but they should be unique within each chat.
|
||||||
|
|
||||||
|
```python
|
||||||
|
tool_call_id = "vAHdf3" # Random ID, should be unique for each tool call
|
||||||
|
tool_call = {"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}
|
||||||
|
messages.append({"role": "assistant", "tool_calls": [{"id": tool_call_id, "type": "function", "function": tool_call}]})
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Now that we've added the tool call to the conversation, we can call the function and append the result to the
|
||||||
|
conversation. Since we're just using a dummy function for this example that always returns 22.0, we can just append
|
||||||
|
that result directly. Again, note the `tool_call_id` - this should match the ID used in the tool call above.
|
||||||
|
|
||||||
|
```python
|
||||||
|
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": "get_current_temperature", "content": "22.0"})
|
||||||
|
```
|
||||||
|
|
||||||
|
Finally, let's let the assistant read the function outputs and continue chatting with the user:
|
||||||
|
|
||||||
|
```python
|
||||||
|
inputs = tokenizer.apply_chat_template(messages, chat_template="tool_use", tools=tools, add_generation_prompt=True, return_dict=True, return_tensors="pt")
|
||||||
|
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
||||||
|
out = model.generate(**inputs, max_new_tokens=128)
|
||||||
|
print(tokenizer.decode(out[0][len(inputs["input_ids"][0]):]))
|
||||||
|
```
|
||||||
|
|
||||||
|
And we get:
|
||||||
|
|
||||||
|
```text
|
||||||
|
The current temperature in Paris, France is 22.0 ° Celsius.<|im_end|>
|
||||||
|
```
|
||||||
|
|
||||||
|
Although this was a simple demo with dummy tools and a single call, the same technique works with
|
||||||
|
multiple real tools and longer conversations. This can be a powerful way to extend the capabilities of conversational
|
||||||
|
agents with real-time information, computational tools like calculators, or access to large databases.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
Not all of the tool-calling features shown above are used by all models. Some use tool call IDs, others simply use the function name and
|
||||||
|
match tool calls to results using the ordering, and there are several models that use neither and only issue one tool
|
||||||
|
call at a time to avoid confusion. If you want your code to be compatible across as many models as possible, we
|
||||||
|
recommend structuring your tools calls like we've shown here, and returning tool results in the order that
|
||||||
|
they were issued by the model. The chat templates on each model should handle the rest.
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
### Understanding tool schemas
|
||||||
|
|
||||||
|
Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
|
||||||
|
[JSON schema](https://json-schema.org/learn/getting-started-step-by-step). These schemas
|
||||||
|
are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they
|
||||||
|
never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they
|
||||||
|
need to pass to them - they care about what the tools do and how to use them, not how they work! It is up to you
|
||||||
|
to read their outputs, detect if they have requested to use a tool, pass their arguments to the tool function, and
|
||||||
|
return the response in the chat.
|
||||||
|
|
||||||
|
Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions
|
||||||
|
follow the specification above, but if you encounter problems, or you simply want more control over the conversion,
|
||||||
|
you can handle the conversion manually. Here is an example of a manual schema conversion.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers.utils import get_json_schema
|
||||||
|
|
||||||
|
def multiply(a: float, b: float):
|
||||||
|
"""
|
||||||
|
A function that multiplies two numbers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: The first number to multiply
|
||||||
|
b: The second number to multiply
|
||||||
|
"""
|
||||||
|
return a * b
|
||||||
|
|
||||||
|
schema = get_json_schema(multiply)
|
||||||
|
print(schema)
|
||||||
|
```
|
||||||
|
|
||||||
|
This will yield:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "multiply",
|
||||||
|
"description": "A function that multiplies two numbers",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The first number to multiply"
|
||||||
|
},
|
||||||
|
"b": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The second number to multiply"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["a", "b"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
|
||||||
|
all. JSON schemas can be passed directly to the `tools` argument of
|
||||||
|
`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful,
|
||||||
|
though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
|
||||||
|
recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
|
||||||
|
to a minimum.
|
||||||
|
|
||||||
|
Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# A simple function that takes no arguments
|
||||||
|
current_time = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "current_time",
|
||||||
|
"description": "Get the current local time as a string.",
|
||||||
|
"parameters": {
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# A more complete function that takes two numerical arguments
|
||||||
|
multiply = {
|
||||||
|
'type': 'function',
|
||||||
|
'function': {
|
||||||
|
'name': 'multiply',
|
||||||
|
'description': 'A function that multiplies two numbers',
|
||||||
|
'parameters': {
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {
|
||||||
|
'a': {
|
||||||
|
'type': 'number',
|
||||||
|
'description': 'The first number to multiply'
|
||||||
|
},
|
||||||
|
'b': {
|
||||||
|
'type': 'number', 'description': 'The second number to multiply'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'required': ['a', 'b']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
model_input = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tools = [current_time, multiply]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced: Retrieval-augmented generation
|
||||||
|
|
||||||
|
"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding
|
||||||
|
to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
|
||||||
|
recommendation for RAG models is that their template
|
||||||
|
should accept a `documents` argument. This should be a list of documents, where each "document"
|
||||||
|
is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler
|
||||||
|
than the JSON schemas used for tools, no helper functions are necessary.
|
||||||
|
|
||||||
|
Here's an example of a RAG template in action:
|
||||||
|
|
||||||
|
```python
|
||||||
|
document1 = {
|
||||||
|
"title": "The Moon: Our Age-Old Foe",
|
||||||
|
"contents": "Man has always dreamed of destroying the moon. In this essay, I shall..."
|
||||||
|
}
|
||||||
|
|
||||||
|
document2 = {
|
||||||
|
"title": "The Sun: Our Age-Old Friend",
|
||||||
|
"contents": "Although often underappreciated, the sun provides several notable benefits..."
|
||||||
|
}
|
||||||
|
|
||||||
|
model_input = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
documents=[document1, document2]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## Advanced: How do chat templates work?
|
## Advanced: How do chat templates work?
|
||||||
|
|
||||||
The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the
|
The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from collections.abc import Mapping, Sized
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from inspect import isfunction
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -47,6 +48,7 @@ from .utils import (
|
|||||||
copy_func,
|
copy_func,
|
||||||
download_url,
|
download_url,
|
||||||
extract_commit_hash,
|
extract_commit_hash,
|
||||||
|
get_json_schema,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_jax_tensor,
|
is_jax_tensor,
|
||||||
is_mlx_available,
|
is_mlx_available,
|
||||||
@@ -1683,6 +1685,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
def apply_chat_template(
|
def apply_chat_template(
|
||||||
self,
|
self,
|
||||||
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
documents: Optional[List[Dict[str, str]]] = None,
|
||||||
chat_template: Optional[str] = None,
|
chat_template: Optional[str] = None,
|
||||||
add_generation_prompt: bool = False,
|
add_generation_prompt: bool = False,
|
||||||
tokenize: bool = True,
|
tokenize: bool = True,
|
||||||
@@ -1703,8 +1707,21 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
Args:
|
Args:
|
||||||
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
|
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
|
||||||
with "role" and "content" keys, representing the chat history so far.
|
with "role" and "content" keys, representing the chat history so far.
|
||||||
chat_template (str, *optional*): A Jinja template to use for this conversion. If
|
tools (`List[Dict]`, *optional*):
|
||||||
this is not passed, the model's default chat template will be used instead.
|
A list of tools (callable functions) that will be accessible to the model. If the template does not
|
||||||
|
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
|
||||||
|
giving the name, description and argument types for the tool. See our
|
||||||
|
[chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
|
||||||
|
for more information.
|
||||||
|
documents (`List[Dict[str, str]]`, *optional*):
|
||||||
|
A list of dicts representing documents that will be accessible to the model if it is performing RAG
|
||||||
|
(retrieval-augmented generation). If the template does not support RAG, this argument will have no
|
||||||
|
effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
|
||||||
|
see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
|
||||||
|
for examples of passing documents with chat templates.
|
||||||
|
chat_template (`str`, *optional*):
|
||||||
|
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
|
||||||
|
argument, as the model's template will be used by default.
|
||||||
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
|
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
|
||||||
the start of an assistant message. This is useful when you want to generate a response from the model.
|
the start of an assistant message. This is useful when you want to generate a response from the model.
|
||||||
Note that this argument will be passed to the chat template, and so it must be supported in the
|
Note that this argument will be passed to the chat template, and so it must be supported in the
|
||||||
@@ -1802,6 +1819,27 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
conversations = [conversation]
|
conversations = [conversation]
|
||||||
is_batched = False
|
is_batched = False
|
||||||
|
|
||||||
|
# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
|
||||||
|
if tools is not None:
|
||||||
|
tool_schemas = []
|
||||||
|
for tool in tools:
|
||||||
|
if isinstance(tool, dict):
|
||||||
|
tool_schemas.append(tool)
|
||||||
|
elif isfunction(tool):
|
||||||
|
tool_schemas.append(get_json_schema(tool))
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Tools should either be a JSON schema, or a callable function with type hints "
|
||||||
|
"and a docstring suitable for auto-conversion to a schema."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tool_schemas = None
|
||||||
|
|
||||||
|
if documents is not None:
|
||||||
|
for document in documents:
|
||||||
|
if not isinstance(document, dict):
|
||||||
|
raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
|
||||||
|
|
||||||
rendered = []
|
rendered = []
|
||||||
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
|
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
|
||||||
for chat in conversations:
|
for chat in conversations:
|
||||||
@@ -1809,7 +1847,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
# Indicates it's a Conversation object
|
# Indicates it's a Conversation object
|
||||||
chat = chat.messages
|
chat = chat.messages
|
||||||
rendered_chat = compiled_template.render(
|
rendered_chat = compiled_template.render(
|
||||||
messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
|
messages=chat,
|
||||||
|
tools=tool_schemas,
|
||||||
|
documents=documents,
|
||||||
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
**template_kwargs,
|
||||||
)
|
)
|
||||||
rendered.append(rendered_chat)
|
rendered.append(rendered_chat)
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from packaging import version
|
|||||||
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .backbone_utils import BackboneConfigMixin, BackboneMixin
|
from .backbone_utils import BackboneConfigMixin, BackboneMixin
|
||||||
|
from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema
|
||||||
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
|
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
|
||||||
from .doc import (
|
from .doc import (
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
|
|||||||
316
src/transformers/utils/chat_template_utils.py
Normal file
316
src/transformers/utils/chat_template_utils.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
|
|
||||||
|
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
|
||||||
|
# Extracts the initial segment of the docstring, containing the function description
|
||||||
|
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
|
||||||
|
# Extracts the Args: block from the docstring
|
||||||
|
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
|
||||||
|
# Splits the Args: block into individual arguments
|
||||||
|
args_split_re = re.compile(
|
||||||
|
r"""
|
||||||
|
(?:^|\n) # Match the start of the args block, or a newline
|
||||||
|
\s*(\w+):\s* # Capture the argument name and strip spacing
|
||||||
|
(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
|
||||||
|
(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
|
||||||
|
""",
|
||||||
|
re.DOTALL | re.VERBOSE,
|
||||||
|
)
|
||||||
|
# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
|
||||||
|
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
|
class TypeHintParsingException(Exception):
|
||||||
|
"""Exception raised for errors in parsing type hints to generate JSON schemas"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DocstringParsingException(Exception):
|
||||||
|
"""Exception raised for errors in parsing docstrings to generate JSON schemas"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
||||||
|
type_mapping = {
|
||||||
|
int: {"type": "integer"},
|
||||||
|
float: {"type": "number"},
|
||||||
|
str: {"type": "string"},
|
||||||
|
bool: {"type": "boolean"},
|
||||||
|
Any: {},
|
||||||
|
}
|
||||||
|
return type_mapping.get(param_type, {"type": "object"})
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_type_hint(hint: str) -> Dict:
|
||||||
|
origin = get_origin(hint)
|
||||||
|
args = get_args(hint)
|
||||||
|
|
||||||
|
if origin is None:
|
||||||
|
try:
|
||||||
|
return _get_json_schema_type(hint)
|
||||||
|
except KeyError:
|
||||||
|
raise TypeHintParsingException(
|
||||||
|
"Couldn't parse this type hint, likely due to a custom class or object: ", hint
|
||||||
|
)
|
||||||
|
|
||||||
|
elif origin is Union:
|
||||||
|
# Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
|
||||||
|
subtypes = [_parse_type_hint(t) for t in args if t != type(None)]
|
||||||
|
if len(subtypes) == 1:
|
||||||
|
# A single non-null type can be expressed directly
|
||||||
|
return_dict = subtypes[0]
|
||||||
|
elif all(isinstance(subtype["type"], str) for subtype in subtypes):
|
||||||
|
# A union of basic types can be expressed as a list in the schema
|
||||||
|
return_dict = {"type": [subtype["type"] for subtype in subtypes]}
|
||||||
|
else:
|
||||||
|
# A union of more complex types requires "anyOf"
|
||||||
|
return_dict = {"anyOf": subtypes}
|
||||||
|
if type(None) in args:
|
||||||
|
return_dict["nullable"] = True
|
||||||
|
return return_dict
|
||||||
|
|
||||||
|
elif origin is list:
|
||||||
|
if not args:
|
||||||
|
return {"type": "array"}
|
||||||
|
else:
|
||||||
|
# Lists can only have a single type argument, so recurse into it
|
||||||
|
return {"type": "array", "items": _parse_type_hint(args[0])}
|
||||||
|
|
||||||
|
elif origin is tuple:
|
||||||
|
if not args:
|
||||||
|
return {"type": "array"}
|
||||||
|
if len(args) == 1:
|
||||||
|
raise TypeHintParsingException(
|
||||||
|
f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
|
||||||
|
"we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
|
||||||
|
"more than one element, we recommend "
|
||||||
|
"using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
|
||||||
|
"pass the element directly."
|
||||||
|
)
|
||||||
|
if ... in args:
|
||||||
|
raise TypeHintParsingException(
|
||||||
|
"Conversion of '...' is not supported in Tuple type hints. "
|
||||||
|
"Use List[] types for variable-length"
|
||||||
|
" inputs instead."
|
||||||
|
)
|
||||||
|
return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
|
||||||
|
|
||||||
|
elif origin is dict:
|
||||||
|
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
|
||||||
|
# However, we can specify the type of the dict values with "additionalProperties"
|
||||||
|
out = {"type": "object"}
|
||||||
|
if len(args) == 2:
|
||||||
|
out["additionalProperties"] = _parse_type_hint(args[1])
|
||||||
|
return out
|
||||||
|
|
||||||
|
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
||||||
|
type_hints = get_type_hints(func)
|
||||||
|
signature = inspect.signature(func)
|
||||||
|
required = []
|
||||||
|
for param_name, param in signature.parameters.items():
|
||||||
|
if param.annotation == inspect.Parameter.empty:
|
||||||
|
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
|
||||||
|
if param.default == inspect.Parameter.empty:
|
||||||
|
required.append(param_name)
|
||||||
|
|
||||||
|
properties = {}
|
||||||
|
for param_name, param_type in type_hints.items():
|
||||||
|
properties[param_name] = _parse_type_hint(param_type)
|
||||||
|
|
||||||
|
schema = {"type": "object", "properties": properties}
|
||||||
|
if required:
|
||||||
|
schema["required"] = required
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
|
||||||
|
"""
|
||||||
|
Parses a Google-style docstring to extract the function description,
|
||||||
|
argument descriptions, and return description.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
docstring (str): The docstring to parse.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The function description, arguments, and return description.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Extract the sections
|
||||||
|
description_match = description_re.search(docstring)
|
||||||
|
args_match = args_re.search(docstring)
|
||||||
|
returns_match = returns_re.search(docstring)
|
||||||
|
|
||||||
|
# Clean and store the sections
|
||||||
|
description = description_match.group(1).strip() if description_match else None
|
||||||
|
docstring_args = args_match.group(1).strip() if args_match else None
|
||||||
|
returns = returns_match.group(1).strip() if returns_match else None
|
||||||
|
|
||||||
|
# Parsing the arguments into a dictionary
|
||||||
|
if docstring_args is not None:
|
||||||
|
docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
|
||||||
|
matches = args_split_re.findall(docstring_args)
|
||||||
|
args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
|
||||||
|
else:
|
||||||
|
args_dict = {}
|
||||||
|
|
||||||
|
return description, args_dict, returns
|
||||||
|
|
||||||
|
|
||||||
|
def get_json_schema(func: Callable) -> Dict:
|
||||||
|
"""
|
||||||
|
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
|
||||||
|
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
|
||||||
|
the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
|
||||||
|
that the function has a docstring, and that each argument has a description in the docstring, in the standard
|
||||||
|
Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
|
||||||
|
|
||||||
|
Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
|
||||||
|
optional because most chat templates ignore the return value of the function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to generate a JSON schema for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing the JSON schema for the function.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
>>> def multiply(x: float, y: float):
|
||||||
|
>>> '''
|
||||||
|
>>> A function that multiplies two numbers
|
||||||
|
>>>
|
||||||
|
>>> Args:
|
||||||
|
>>> x: The first number to multiply
|
||||||
|
>>> y: The second number to multiply
|
||||||
|
>>> '''
|
||||||
|
>>> return x * y
|
||||||
|
>>>
|
||||||
|
>>> print(get_json_schema(multiply))
|
||||||
|
{
|
||||||
|
"name": "multiply",
|
||||||
|
"description": "A function that multiplies two numbers",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {"type": "number", "description": "The first number to multiply"},
|
||||||
|
"y": {"type": "number", "description": "The second number to multiply"}
|
||||||
|
},
|
||||||
|
"required": ["x", "y"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The general use for these schemas is that they are used to generate tool descriptions for chat templates that
|
||||||
|
support them, like so:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer
|
||||||
|
>>> from transformers.utils import get_json_schema
|
||||||
|
>>>
|
||||||
|
>>> def multiply(x: float, y: float):
|
||||||
|
>>> '''
|
||||||
|
>>> A function that multiplies two numbers
|
||||||
|
>>>
|
||||||
|
>>> Args:
|
||||||
|
>>> x: The first number to multiply
|
||||||
|
>>> y: The second number to multiply
|
||||||
|
>>> return x * y
|
||||||
|
>>> '''
|
||||||
|
>>>
|
||||||
|
>>> multiply_schema = get_json_schema(multiply)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||||
|
>>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
|
||||||
|
>>> formatted_chat = tokenizer.apply_chat_template(
|
||||||
|
>>> messages,
|
||||||
|
>>> tools=[multiply_schema],
|
||||||
|
>>> chat_template="tool_use",
|
||||||
|
>>> return_dict=True,
|
||||||
|
>>> return_tensors="pt",
|
||||||
|
>>> add_generation_prompt=True
|
||||||
|
>>> )
|
||||||
|
>>> # The formatted chat can now be passed to model.generate()
|
||||||
|
```
|
||||||
|
|
||||||
|
Each argument description can also have an optional `(choices: ...)` block at the end, such as
|
||||||
|
`(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
|
||||||
|
only be parsed correctly if it is at the end of the line:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> def drink_beverage(beverage: str):
|
||||||
|
>>> '''
|
||||||
|
>>> A function that drinks a beverage
|
||||||
|
>>>
|
||||||
|
>>> Args:
|
||||||
|
>>> beverage: The beverage to drink (choices: ["tea", "coffee"])
|
||||||
|
>>> '''
|
||||||
|
>>> pass
|
||||||
|
>>>
|
||||||
|
>>> print(get_json_schema(drink_beverage))
|
||||||
|
```
|
||||||
|
{
|
||||||
|
'name': 'drink_beverage',
|
||||||
|
'description': 'A function that drinks a beverage',
|
||||||
|
'parameters': {
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {
|
||||||
|
'beverage': {
|
||||||
|
'type': 'string',
|
||||||
|
'enum': ['tea', 'coffee'],
|
||||||
|
'description': 'The beverage to drink'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'required': ['beverage']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
doc = inspect.getdoc(func)
|
||||||
|
if not doc:
|
||||||
|
raise DocstringParsingException(
|
||||||
|
f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
|
||||||
|
)
|
||||||
|
doc = doc.strip()
|
||||||
|
main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)
|
||||||
|
|
||||||
|
json_schema = _convert_type_hints_to_json_schema(func)
|
||||||
|
if (return_dict := json_schema["properties"].pop("return", None)) is not None:
|
||||||
|
if return_doc is not None: # We allow a missing return docstring since most templates ignore it
|
||||||
|
return_dict["description"] = return_doc
|
||||||
|
for arg, schema in json_schema["properties"].items():
|
||||||
|
if arg not in param_descriptions:
|
||||||
|
raise DocstringParsingException(
|
||||||
|
f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
|
||||||
|
)
|
||||||
|
desc = param_descriptions[arg]
|
||||||
|
enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
|
||||||
|
if enum_choices:
|
||||||
|
schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
|
||||||
|
desc = enum_choices.string[: enum_choices.start()].strip()
|
||||||
|
schema["description"] = desc
|
||||||
|
|
||||||
|
output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
|
||||||
|
if return_dict is not None:
|
||||||
|
output["return"] = return_dict
|
||||||
|
return {"type": "function", "function": output}
|
||||||
476
tests/utils/test_chat_template_utils.py
Normal file
476
tests/utils/test_chat_template_utils.py
Normal file
@@ -0,0 +1,476 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
|
||||||
|
|
||||||
|
|
||||||
|
class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||||
|
def test_simple_function(self):
|
||||||
|
def fn(x: int):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_no_arguments(self):
|
||||||
|
def fn():
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_union(self):
|
||||||
|
def fn(x: Union[int, float]):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"x": {"type": ["integer", "number"], "description": "The input"}},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_optional(self):
|
||||||
|
def fn(x: Optional[int]):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"x": {"type": "integer", "description": "The input", "nullable": True}},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_default_arg(self):
|
||||||
|
def fn(x: int = 42):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_nested_list(self):
|
||||||
|
def fn(x: List[List[Union[str, int]]]):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "array", "items": {"type": ["string", "integer"]}},
|
||||||
|
"description": "The input",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_multiple_arguments(self):
|
||||||
|
def fn(x: int, y: str):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
y: Also the input
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {"type": "integer", "description": "The input"},
|
||||||
|
"y": {"type": "string", "description": "Also the input"},
|
||||||
|
},
|
||||||
|
"required": ["x", "y"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_multiple_complex_arguments(self):
|
||||||
|
def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
y: Also the input
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"},
|
||||||
|
"y": {
|
||||||
|
"type": ["integer", "string"],
|
||||||
|
"nullable": True,
|
||||||
|
"description": "Also the input",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_missing_docstring(self):
|
||||||
|
def fn(x: int):
|
||||||
|
return x
|
||||||
|
|
||||||
|
with self.assertRaises(DocstringParsingException):
|
||||||
|
get_json_schema(fn)
|
||||||
|
|
||||||
|
def test_missing_param_docstring(self):
|
||||||
|
def fn(x: int):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
with self.assertRaises(DocstringParsingException):
|
||||||
|
get_json_schema(fn)
|
||||||
|
|
||||||
|
def test_missing_type_hint(self):
|
||||||
|
def fn(x):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
with self.assertRaises(TypeHintParsingException):
|
||||||
|
get_json_schema(fn)
|
||||||
|
|
||||||
|
def test_return_value(self):
|
||||||
|
def fn(x: int) -> int:
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
"return": {"type": "integer"},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_return_value_docstring(self):
|
||||||
|
def fn(x: int) -> int:
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
"return": {"type": "integer", "description": "The output"},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_tuple(self):
|
||||||
|
def fn(x: Tuple[int, str]):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "array",
|
||||||
|
"prefixItems": [{"type": "integer"}, {"type": "string"}],
|
||||||
|
"description": "The input",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["x"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_single_element_tuple_fails(self):
|
||||||
|
def fn(x: Tuple[int]):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
# Single-element tuples should just be the type itself, or List[type] for variable-length inputs
|
||||||
|
with self.assertRaises(TypeHintParsingException):
|
||||||
|
get_json_schema(fn)
|
||||||
|
|
||||||
|
def test_ellipsis_type_fails(self):
|
||||||
|
def fn(x: Tuple[int, ...]):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
# Variable length inputs should be specified with List[type], not Tuple[type, ...]
|
||||||
|
with self.assertRaises(TypeHintParsingException):
|
||||||
|
get_json_schema(fn)
|
||||||
|
|
||||||
|
def test_enum_extraction(self):
|
||||||
|
def fn(temperature_format: str):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"])
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The temperature
|
||||||
|
"""
|
||||||
|
return -40.0
|
||||||
|
|
||||||
|
# Let's see if that gets correctly parsed as an enum
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"temperature_format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"description": "The temperature format to use",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["temperature_format"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_multiline_docstring_with_types(self):
|
||||||
|
def fn(x: int, y: int):
|
||||||
|
"""
|
||||||
|
Test function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The first input
|
||||||
|
|
||||||
|
y: The second input. This is a longer description
|
||||||
|
that spans multiple lines with indentation and stuff.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
God knows what
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {"type": "integer", "description": "The first input"},
|
||||||
|
"y": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["x", "y"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
|
|
||||||
|
def test_everything_all_at_once(self):
|
||||||
|
def fn(
|
||||||
|
x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello")
|
||||||
|
) -> Tuple[int, str]:
|
||||||
|
"""
|
||||||
|
Test function with multiple args, and docstring args that we have to strip out.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The first input. It's got a big multiline
|
||||||
|
description and also contains
|
||||||
|
(choices: ["a", "b", "c"])
|
||||||
|
|
||||||
|
y: The second input. It's a big list with a single-line description.
|
||||||
|
|
||||||
|
z: The third input. It's some kind of tuple with a default arg.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output. The return description is also a big multiline
|
||||||
|
description that spans multiple lines.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
schema = get_json_schema(fn)
|
||||||
|
expected_schema = {
|
||||||
|
"name": "fn",
|
||||||
|
"description": "Test function with multiple args, and docstring args that we have to strip out.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["a", "b", "c"],
|
||||||
|
"description": "The first input. It's got a big multiline description and also contains",
|
||||||
|
},
|
||||||
|
"y": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": ["string", "integer"]},
|
||||||
|
"nullable": True,
|
||||||
|
"description": "The second input. It's a big list with a single-line description.",
|
||||||
|
},
|
||||||
|
"z": {
|
||||||
|
"type": "array",
|
||||||
|
"prefixItems": [{"type": ["string", "integer"]}, {"type": "string"}],
|
||||||
|
"description": "The third input. It's some kind of tuple with a default arg.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["x", "y"],
|
||||||
|
},
|
||||||
|
"return": {
|
||||||
|
"type": "array",
|
||||||
|
"prefixItems": [{"type": "integer"}, {"type": "string"}],
|
||||||
|
"description": "The output. The return description is also a big multiline\n description that spans multiple lines.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.assertEqual(schema["function"], expected_schema)
|
||||||
Reference in New Issue
Block a user