Tool types (#24032)

* Tool types

* Tests + fixes

* Isolate types

* Oops

* Review comments + docs

* Tests + docs

* soundfile -> vision
This commit is contained in:
Lysandre Debut
2023-06-09 13:34:07 -04:00
committed by GitHub
parent 061580c82c
commit deff5979fe
8 changed files with 521 additions and 15 deletions

View File

@@ -18,6 +18,7 @@ from typing import List
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir, is_tool_test
from transformers.tools.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
if is_torch_available():
@@ -54,11 +55,11 @@ def output_types(outputs: List):
output_types = []
for output in outputs:
if isinstance(output, str):
if isinstance(output, (str, AgentText)):
output_types.append("text")
elif isinstance(output, Image.Image):
elif isinstance(output, (Image.Image, AgentImage)):
output_types.append("image")
elif isinstance(output, torch.Tensor):
elif isinstance(output, (torch.Tensor, AgentAudio)):
output_types.append("audio")
else:
raise ValueError(f"Invalid output: {output}")
@@ -98,3 +99,35 @@ class ToolTesterMixin:
self.assertTrue(hasattr(self.tool, "description"))
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
self.assertTrue(self.tool.description.startswith("This is a tool that"))
def test_agent_types_outputs(self):
inputs = create_inputs(self.tool.inputs)
outputs = self.tool(*inputs)
if not isinstance(outputs, list):
outputs = [outputs]
self.assertEqual(len(outputs), len(self.tool.outputs))
for output, output_type in zip(outputs, self.tool.outputs):
agent_type = AGENT_TYPE_MAPPING[output_type]
self.assertTrue(isinstance(output, agent_type))
def test_agent_types_inputs(self):
inputs = create_inputs(self.tool.inputs)
_inputs = []
for _input, input_type in zip(inputs, self.tool.inputs):
if isinstance(input_type, list):
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
# Should not raise an error
outputs = self.tool(*inputs)
if not isinstance(outputs, list):
outputs = [outputs]
self.assertEqual(len(outputs), len(self.tool.outputs))