Tool types (#24032)
* Tool types * Tests + fixes * Isolate types * Oops * Review comments + docs * Tests + docs * soundfile -> vision
This commit is contained in:
@@ -70,3 +70,32 @@ We provide three types of agents: [`HfAgent`] uses inference endpoints for opens
|
|||||||
### launch_gradio_demo
|
### launch_gradio_demo
|
||||||
|
|
||||||
[[autodoc]] launch_gradio_demo
|
[[autodoc]] launch_gradio_demo
|
||||||
|
|
||||||
|
## Agent Types
|
||||||
|
|
||||||
|
Agents can handle any type of object in-between tools; tools, being completely multimodal, can accept and return
|
||||||
|
text, image, audio, video, among other types. In order to increase compatibility between tools, as well as to
|
||||||
|
correctly render these returns in ipython (jupyter, colab, ipython notebooks, ...), we implement wrapper classes
|
||||||
|
around these types.
|
||||||
|
|
||||||
|
The wrapped objects should continue behaving as initially; a text object should still behave as a string, an image
|
||||||
|
object should still behave as a `PIL.Image`.
|
||||||
|
|
||||||
|
These types have three specific purposes:
|
||||||
|
|
||||||
|
- Calling `to_raw` on the type should return the underlying object
|
||||||
|
- Calling `to_string` on the type should return the object as a string: that can be the string in case of an `AgentText`
|
||||||
|
but will be the path of the serialized version of the object in other instances
|
||||||
|
- Displaying it in an ipython kernel should display the object correctly
|
||||||
|
|
||||||
|
### AgentText
|
||||||
|
|
||||||
|
[[autodoc]] transformers.tools.agent_types.AgentText
|
||||||
|
|
||||||
|
### AgentImage
|
||||||
|
|
||||||
|
[[autodoc]] transformers.tools.agent_types.AgentImage
|
||||||
|
|
||||||
|
### AgentAudio
|
||||||
|
|
||||||
|
[[autodoc]] transformers.tools.agent_types.AgentAudio
|
||||||
|
|||||||
277
src/transformers/tools/agent_types.py
Normal file
277
src/transformers/tools/agent_types.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import tempfile
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
import PIL.Image
|
||||||
|
from PIL import Image
|
||||||
|
from PIL.Image import Image as ImageType
|
||||||
|
else:
|
||||||
|
ImageType = object
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_soundfile_availble():
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
class AgentType:
|
||||||
|
"""
|
||||||
|
Abstract class to be reimplemented to define types that can be returned by agents.
|
||||||
|
|
||||||
|
These objects serve three purposes:
|
||||||
|
|
||||||
|
- They behave as they were the type they're meant to be, e.g., a string for text, a PIL.Image for images
|
||||||
|
- They can be stringified: str(object) in order to return a string defining the object
|
||||||
|
- They should be displayed correctly in ipython notebooks/colab/jupyter
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
self._value = value
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.to_string()
|
||||||
|
|
||||||
|
def to_raw(self):
|
||||||
|
logger.error(
|
||||||
|
"This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
|
||||||
|
)
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
logger.error(
|
||||||
|
"This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
|
||||||
|
)
|
||||||
|
return str(self._value)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentText(AgentType, str):
|
||||||
|
"""
|
||||||
|
Text type returned by the agent. Behaves as a string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def to_raw(self):
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
def to_string(self):
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
|
||||||
|
class AgentImage(AgentType, ImageType):
|
||||||
|
"""
|
||||||
|
Image type returned by the agent. Behaves as a PIL.Image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
super().__init__(value)
|
||||||
|
|
||||||
|
if not is_vision_available():
|
||||||
|
raise ImportError("PIL must be installed in order to handle images.")
|
||||||
|
|
||||||
|
self._path = None
|
||||||
|
self._raw = None
|
||||||
|
self._tensor = None
|
||||||
|
|
||||||
|
if isinstance(value, ImageType):
|
||||||
|
self._raw = value
|
||||||
|
elif isinstance(value, (str, pathlib.Path)):
|
||||||
|
self._path = value
|
||||||
|
elif isinstance(value, torch.Tensor):
|
||||||
|
self._tensor = value
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
|
||||||
|
|
||||||
|
def _ipython_display_(self, include=None, exclude=None):
|
||||||
|
"""
|
||||||
|
Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
|
||||||
|
"""
|
||||||
|
from IPython.display import Image, display
|
||||||
|
|
||||||
|
display(Image(self.to_string()))
|
||||||
|
|
||||||
|
def to_raw(self):
|
||||||
|
"""
|
||||||
|
Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
|
||||||
|
"""
|
||||||
|
if self._raw is not None:
|
||||||
|
return self._raw
|
||||||
|
|
||||||
|
if self._path is not None:
|
||||||
|
self._raw = Image.open(self._path)
|
||||||
|
return self._raw
|
||||||
|
|
||||||
|
def to_string(self):
|
||||||
|
"""
|
||||||
|
Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
|
||||||
|
version of the image.
|
||||||
|
"""
|
||||||
|
if self._path is not None:
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
if self._raw is not None:
|
||||||
|
directory = tempfile.mkdtemp()
|
||||||
|
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
|
||||||
|
self._raw.save(self._path)
|
||||||
|
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
if self._tensor is not None:
|
||||||
|
array = self._tensor.cpu().detach().numpy()
|
||||||
|
|
||||||
|
# There is likely simpler than load into image into save
|
||||||
|
img = Image.fromarray((array * 255).astype(np.uint8))
|
||||||
|
|
||||||
|
directory = tempfile.mkdtemp()
|
||||||
|
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
|
||||||
|
|
||||||
|
img.save(self._path)
|
||||||
|
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
|
||||||
|
class AgentAudio(AgentType):
|
||||||
|
"""
|
||||||
|
Audio type returned by the agent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, value, samplerate=16_000):
|
||||||
|
super().__init__(value)
|
||||||
|
|
||||||
|
if not is_soundfile_availble():
|
||||||
|
raise ImportError("soundfile must be installed in order to handle audio.")
|
||||||
|
|
||||||
|
self._path = None
|
||||||
|
self._tensor = None
|
||||||
|
|
||||||
|
self.samplerate = samplerate
|
||||||
|
|
||||||
|
if isinstance(value, (str, pathlib.Path)):
|
||||||
|
self._path = value
|
||||||
|
elif isinstance(value, torch.Tensor):
|
||||||
|
self._tensor = value
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported audio type: {type(value)}")
|
||||||
|
|
||||||
|
def _ipython_display_(self, include=None, exclude=None):
|
||||||
|
"""
|
||||||
|
Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
|
||||||
|
"""
|
||||||
|
from IPython.display import Audio, display
|
||||||
|
|
||||||
|
display(Audio(self.to_string(), rate=self.samplerate))
|
||||||
|
|
||||||
|
def to_raw(self):
|
||||||
|
"""
|
||||||
|
Returns the "raw" version of that object. It is a `torch.Tensor` object.
|
||||||
|
"""
|
||||||
|
if self._tensor is not None:
|
||||||
|
return self._tensor
|
||||||
|
|
||||||
|
if self._path is not None:
|
||||||
|
tensor, self.samplerate = sf.read(self._path)
|
||||||
|
self._tensor = torch.tensor(tensor)
|
||||||
|
return self._tensor
|
||||||
|
|
||||||
|
def to_string(self):
|
||||||
|
"""
|
||||||
|
Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
|
||||||
|
version of the audio.
|
||||||
|
"""
|
||||||
|
if self._path is not None:
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
if self._tensor is not None:
|
||||||
|
directory = tempfile.mkdtemp()
|
||||||
|
self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
|
||||||
|
sf.write(self._path, self._tensor, samplerate=self.samplerate)
|
||||||
|
return self._path
|
||||||
|
|
||||||
|
|
||||||
|
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
|
||||||
|
INSTANCE_TYPE_MAPPING = {str: AgentText}
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
INSTANCE_TYPE_MAPPING[PIL.Image] = AgentImage
|
||||||
|
|
||||||
|
|
||||||
|
def handle_agent_inputs(*args, **kwargs):
|
||||||
|
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
|
||||||
|
kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
|
||||||
|
return args, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def handle_agent_outputs(outputs, output_types=None):
|
||||||
|
if isinstance(outputs, dict):
|
||||||
|
decoded_outputs = {}
|
||||||
|
for i, (k, v) in enumerate(outputs.items()):
|
||||||
|
if output_types is not None:
|
||||||
|
# If the class has defined outputs, we can map directly according to the class definition
|
||||||
|
if output_types[i] in AGENT_TYPE_MAPPING:
|
||||||
|
decoded_outputs[k] = AGENT_TYPE_MAPPING[output_types[i]](v)
|
||||||
|
else:
|
||||||
|
decoded_outputs[k] = AgentType(v)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# If the class does not have defined output, then we map according to the type
|
||||||
|
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
||||||
|
if isinstance(v, _k):
|
||||||
|
decoded_outputs[k] = _v(v)
|
||||||
|
if k not in decoded_outputs:
|
||||||
|
decoded_outputs[k] = AgentType[v]
|
||||||
|
|
||||||
|
elif isinstance(outputs, (list, tuple)):
|
||||||
|
decoded_outputs = type(outputs)()
|
||||||
|
for i, v in enumerate(outputs):
|
||||||
|
if output_types is not None:
|
||||||
|
# If the class has defined outputs, we can map directly according to the class definition
|
||||||
|
if output_types[i] in AGENT_TYPE_MAPPING:
|
||||||
|
decoded_outputs.append(AGENT_TYPE_MAPPING[output_types[i]](v))
|
||||||
|
else:
|
||||||
|
decoded_outputs.append(AgentType(v))
|
||||||
|
else:
|
||||||
|
# If the class does not have defined output, then we map according to the type
|
||||||
|
found = False
|
||||||
|
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
||||||
|
if isinstance(v, _k):
|
||||||
|
decoded_outputs.append(_v(v))
|
||||||
|
found = True
|
||||||
|
|
||||||
|
if not found:
|
||||||
|
decoded_outputs.append(AgentType(v))
|
||||||
|
|
||||||
|
else:
|
||||||
|
if output_types[0] in AGENT_TYPE_MAPPING:
|
||||||
|
# If the class has defined outputs, we can map directly according to the class definition
|
||||||
|
decoded_outputs = AGENT_TYPE_MAPPING[output_types[0]](outputs)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# If the class does not have defined output, then we map according to the type
|
||||||
|
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
||||||
|
if isinstance(outputs, _k):
|
||||||
|
return _v(outputs)
|
||||||
|
return AgentType(outputs)
|
||||||
|
|
||||||
|
return decoded_outputs
|
||||||
@@ -37,6 +37,7 @@ from ..utils import (
|
|||||||
is_vision_available,
|
is_vision_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
|
from .agent_types import handle_agent_inputs, handle_agent_outputs
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -413,6 +414,8 @@ class RemoteTool(Tool):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
||||||
|
|
||||||
output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
|
output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
|
||||||
inputs = self.prepare_inputs(*args, **kwargs)
|
inputs = self.prepare_inputs(*args, **kwargs)
|
||||||
if isinstance(inputs, dict):
|
if isinstance(inputs, dict):
|
||||||
@@ -421,6 +424,9 @@ class RemoteTool(Tool):
|
|||||||
outputs = self.client(inputs, output_image=output_image)
|
outputs = self.client(inputs, output_image=output_image)
|
||||||
if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
|
if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
|
||||||
outputs = outputs[0]
|
outputs = outputs[0]
|
||||||
|
|
||||||
|
outputs = handle_agent_outputs(outputs, self.tool_class.outputs if self.tool_class is not None else None)
|
||||||
|
|
||||||
return self.extract_outputs(outputs)
|
return self.extract_outputs(outputs)
|
||||||
|
|
||||||
|
|
||||||
@@ -550,6 +556,8 @@ class PipelineTool(Tool):
|
|||||||
return self.post_processor(outputs)
|
return self.post_processor(outputs)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
|
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
||||||
|
|
||||||
if not self.is_initialized:
|
if not self.is_initialized:
|
||||||
self.setup()
|
self.setup()
|
||||||
|
|
||||||
@@ -557,7 +565,9 @@ class PipelineTool(Tool):
|
|||||||
encoded_inputs = send_to_device(encoded_inputs, self.device)
|
encoded_inputs = send_to_device(encoded_inputs, self.device)
|
||||||
outputs = self.forward(encoded_inputs)
|
outputs = self.forward(encoded_inputs)
|
||||||
outputs = send_to_device(outputs, "cpu")
|
outputs = send_to_device(outputs, "cpu")
|
||||||
return self.decode(outputs)
|
decoded_outputs = self.decode(outputs)
|
||||||
|
|
||||||
|
return handle_agent_outputs(decoded_outputs, self.outputs)
|
||||||
|
|
||||||
|
|
||||||
def launch_gradio_demo(tool_class: Tool):
|
def launch_gradio_demo(tool_class: Tool):
|
||||||
|
|||||||
121
tests/tools/test_agent_types.py
Normal file
121
tests/tools/test_agent_types.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
|
||||||
|
from transformers.tools.agent_types import AgentAudio, AgentImage, AgentText
|
||||||
|
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_soundfile_availble():
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def get_new_path(suffix="") -> str:
|
||||||
|
directory = tempfile.mkdtemp()
|
||||||
|
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
||||||
|
|
||||||
|
|
||||||
|
@require_soundfile
|
||||||
|
@require_torch
|
||||||
|
class AgentAudioTests(unittest.TestCase):
|
||||||
|
def test_from_tensor(self):
|
||||||
|
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||||
|
agent_type = AgentAudio(tensor)
|
||||||
|
path = str(agent_type.to_string())
|
||||||
|
|
||||||
|
# Ensure that the tensor and the agent_type's tensor are the same
|
||||||
|
self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
|
||||||
|
|
||||||
|
del agent_type
|
||||||
|
|
||||||
|
# Ensure the path remains even after the object deletion
|
||||||
|
self.assertTrue(os.path.exists(path))
|
||||||
|
|
||||||
|
# Ensure that the file contains the same value as the original tensor
|
||||||
|
new_tensor, _ = sf.read(path)
|
||||||
|
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
|
||||||
|
|
||||||
|
def test_from_string(self):
|
||||||
|
tensor = torch.rand(12, dtype=torch.float64) - 0.5
|
||||||
|
path = get_new_path(suffix=".wav")
|
||||||
|
sf.write(path, tensor, 16000)
|
||||||
|
|
||||||
|
agent_type = AgentAudio(path)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
|
||||||
|
self.assertEqual(agent_type.to_string(), path)
|
||||||
|
|
||||||
|
|
||||||
|
@require_vision
|
||||||
|
@require_torch
|
||||||
|
class AgentImageTests(unittest.TestCase):
|
||||||
|
def test_from_tensor(self):
|
||||||
|
tensor = torch.randint(0, 256, (64, 64, 3))
|
||||||
|
agent_type = AgentImage(tensor)
|
||||||
|
path = str(agent_type.to_string())
|
||||||
|
|
||||||
|
# Ensure that the tensor and the agent_type's tensor are the same
|
||||||
|
self.assertTrue(torch.allclose(tensor, agent_type._tensor, atol=1e-4))
|
||||||
|
|
||||||
|
self.assertIsInstance(agent_type.to_raw(), Image.Image)
|
||||||
|
|
||||||
|
# Ensure the path remains even after the object deletion
|
||||||
|
del agent_type
|
||||||
|
self.assertTrue(os.path.exists(path))
|
||||||
|
|
||||||
|
def test_from_string(self):
|
||||||
|
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||||
|
image = Image.open(path)
|
||||||
|
agent_type = AgentImage(path)
|
||||||
|
|
||||||
|
self.assertTrue(path.samefile(agent_type.to_string()))
|
||||||
|
self.assertTrue(image == agent_type.to_raw())
|
||||||
|
|
||||||
|
# Ensure the path remains even after the object deletion
|
||||||
|
del agent_type
|
||||||
|
self.assertTrue(os.path.exists(path))
|
||||||
|
|
||||||
|
def test_from_image(self):
|
||||||
|
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||||
|
image = Image.open(path)
|
||||||
|
agent_type = AgentImage(image)
|
||||||
|
|
||||||
|
self.assertFalse(path.samefile(agent_type.to_string()))
|
||||||
|
self.assertTrue(image == agent_type.to_raw())
|
||||||
|
|
||||||
|
# Ensure the path remains even after the object deletion
|
||||||
|
del agent_type
|
||||||
|
self.assertTrue(os.path.exists(path))
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTextTests(unittest.TestCase):
|
||||||
|
def test_from_string(self):
|
||||||
|
string = "Hey!"
|
||||||
|
agent_type = AgentText(string)
|
||||||
|
|
||||||
|
self.assertEqual(string, agent_type.to_string())
|
||||||
|
self.assertEqual(string, agent_type.to_raw())
|
||||||
|
self.assertEqual(string, agent_type)
|
||||||
@@ -30,28 +30,27 @@ class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
|||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||||
image = dataset[0]["image"]
|
document = dataset[0]["image"]
|
||||||
|
|
||||||
result = self.tool(image, "When is the coffee break?")
|
result = self.tool(document, "When is the coffee break?")
|
||||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||||
|
|
||||||
def test_exact_match_arg_remote(self):
|
def test_exact_match_arg_remote(self):
|
||||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||||
image = dataset[0]["image"]
|
document = dataset[0]["image"]
|
||||||
|
|
||||||
result = self.remote_tool(image, "When is the coffee break?")
|
result = self.remote_tool(document, "When is the coffee break?")
|
||||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||||
|
|
||||||
def test_exact_match_kwarg(self):
|
def test_exact_match_kwarg(self):
|
||||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||||
image = dataset[0]["image"]
|
document = dataset[0]["image"]
|
||||||
|
|
||||||
result = self.tool(image=image, question="When is the coffee break?")
|
self.tool(document=document, question="When is the coffee break?")
|
||||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
|
||||||
|
|
||||||
def test_exact_match_kwarg_remote(self):
|
def test_exact_match_kwarg_remote(self):
|
||||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||||
image = dataset[0]["image"]
|
document = dataset[0]["image"]
|
||||||
|
|
||||||
result = self.remote_tool(image=image, question="When is the coffee break?")
|
result = self.remote_tool(document=document, question="When is the coffee break?")
|
||||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||||
|
|||||||
@@ -37,9 +37,11 @@ class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
|||||||
# SpeechT5 isn't deterministic
|
# SpeechT5 isn't deterministic
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
result = self.tool("hey")
|
result = self.tool("hey")
|
||||||
|
resulting_tensor = result.to_raw()
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.allclose(
|
torch.allclose(
|
||||||
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
|
resulting_tensor[:3],
|
||||||
|
torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,8 +49,10 @@ class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
|||||||
# SpeechT5 isn't deterministic
|
# SpeechT5 isn't deterministic
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
result = self.tool("hey")
|
result = self.tool("hey")
|
||||||
|
resulting_tensor = result.to_raw()
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.allclose(
|
torch.allclose(
|
||||||
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
|
resulting_tensor[:3],
|
||||||
|
torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing import List
|
|||||||
|
|
||||||
from transformers import is_torch_available, is_vision_available
|
from transformers import is_torch_available, is_vision_available
|
||||||
from transformers.testing_utils import get_tests_dir, is_tool_test
|
from transformers.testing_utils import get_tests_dir, is_tool_test
|
||||||
|
from transformers.tools.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -54,11 +55,11 @@ def output_types(outputs: List):
|
|||||||
output_types = []
|
output_types = []
|
||||||
|
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
if isinstance(output, str):
|
if isinstance(output, (str, AgentText)):
|
||||||
output_types.append("text")
|
output_types.append("text")
|
||||||
elif isinstance(output, Image.Image):
|
elif isinstance(output, (Image.Image, AgentImage)):
|
||||||
output_types.append("image")
|
output_types.append("image")
|
||||||
elif isinstance(output, torch.Tensor):
|
elif isinstance(output, (torch.Tensor, AgentAudio)):
|
||||||
output_types.append("audio")
|
output_types.append("audio")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid output: {output}")
|
raise ValueError(f"Invalid output: {output}")
|
||||||
@@ -98,3 +99,35 @@ class ToolTesterMixin:
|
|||||||
self.assertTrue(hasattr(self.tool, "description"))
|
self.assertTrue(hasattr(self.tool, "description"))
|
||||||
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
|
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
|
||||||
self.assertTrue(self.tool.description.startswith("This is a tool that"))
|
self.assertTrue(self.tool.description.startswith("This is a tool that"))
|
||||||
|
|
||||||
|
def test_agent_types_outputs(self):
|
||||||
|
inputs = create_inputs(self.tool.inputs)
|
||||||
|
outputs = self.tool(*inputs)
|
||||||
|
|
||||||
|
if not isinstance(outputs, list):
|
||||||
|
outputs = [outputs]
|
||||||
|
|
||||||
|
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||||
|
|
||||||
|
for output, output_type in zip(outputs, self.tool.outputs):
|
||||||
|
agent_type = AGENT_TYPE_MAPPING[output_type]
|
||||||
|
self.assertTrue(isinstance(output, agent_type))
|
||||||
|
|
||||||
|
def test_agent_types_inputs(self):
|
||||||
|
inputs = create_inputs(self.tool.inputs)
|
||||||
|
|
||||||
|
_inputs = []
|
||||||
|
|
||||||
|
for _input, input_type in zip(inputs, self.tool.inputs):
|
||||||
|
if isinstance(input_type, list):
|
||||||
|
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||||
|
else:
|
||||||
|
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||||
|
|
||||||
|
# Should not raise an error
|
||||||
|
outputs = self.tool(*inputs)
|
||||||
|
|
||||||
|
if not isinstance(outputs, list):
|
||||||
|
outputs = [outputs]
|
||||||
|
|
||||||
|
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import load_tool
|
from transformers import load_tool
|
||||||
|
from transformers.tools.agent_types import AGENT_TYPE_MAPPING
|
||||||
|
|
||||||
from .test_tools_common import ToolTesterMixin, output_types
|
from .test_tools_common import ToolTesterMixin, output_types
|
||||||
|
|
||||||
@@ -51,3 +52,35 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
|||||||
outputs = [outputs]
|
outputs = [outputs]
|
||||||
|
|
||||||
self.assertListEqual(output_types(outputs), self.tool.outputs)
|
self.assertListEqual(output_types(outputs), self.tool.outputs)
|
||||||
|
|
||||||
|
def test_agent_types_outputs(self):
|
||||||
|
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||||
|
outputs = self.tool(*inputs)
|
||||||
|
|
||||||
|
if not isinstance(outputs, list):
|
||||||
|
outputs = [outputs]
|
||||||
|
|
||||||
|
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||||
|
|
||||||
|
for output, output_type in zip(outputs, self.tool.outputs):
|
||||||
|
agent_type = AGENT_TYPE_MAPPING[output_type]
|
||||||
|
self.assertTrue(isinstance(output, agent_type))
|
||||||
|
|
||||||
|
def test_agent_types_inputs(self):
|
||||||
|
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||||
|
|
||||||
|
_inputs = []
|
||||||
|
|
||||||
|
for _input, input_type in zip(inputs, self.tool.inputs):
|
||||||
|
if isinstance(input_type, list):
|
||||||
|
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||||
|
else:
|
||||||
|
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||||
|
|
||||||
|
# Should not raise an error
|
||||||
|
outputs = self.tool(*inputs)
|
||||||
|
|
||||||
|
if not isinstance(outputs, list):
|
||||||
|
outputs = [outputs]
|
||||||
|
|
||||||
|
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||||
|
|||||||
Reference in New Issue
Block a user