Tool types (#24032)
* Tool types * Tests + fixes * Isolate types * Oops * Review comments + docs * Tests + docs * soundfile -> vision
This commit is contained in:
@@ -18,6 +18,7 @@ from typing import List
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import get_tests_dir, is_tool_test
|
||||
from transformers.tools.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -54,11 +55,11 @@ def output_types(outputs: List):
|
||||
output_types = []
|
||||
|
||||
for output in outputs:
|
||||
if isinstance(output, str):
|
||||
if isinstance(output, (str, AgentText)):
|
||||
output_types.append("text")
|
||||
elif isinstance(output, Image.Image):
|
||||
elif isinstance(output, (Image.Image, AgentImage)):
|
||||
output_types.append("image")
|
||||
elif isinstance(output, torch.Tensor):
|
||||
elif isinstance(output, (torch.Tensor, AgentAudio)):
|
||||
output_types.append("audio")
|
||||
else:
|
||||
raise ValueError(f"Invalid output: {output}")
|
||||
@@ -98,3 +99,35 @@ class ToolTesterMixin:
|
||||
self.assertTrue(hasattr(self.tool, "description"))
|
||||
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
|
||||
self.assertTrue(self.tool.description.startswith("This is a tool that"))
|
||||
|
||||
def test_agent_types_outputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||
|
||||
for output, output_type in zip(outputs, self.tool.outputs):
|
||||
agent_type = AGENT_TYPE_MAPPING[output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
|
||||
_inputs = []
|
||||
|
||||
for _input, input_type in zip(inputs, self.tool.inputs):
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||
|
||||
Reference in New Issue
Block a user