From deff5979fee1f989d26e4946c92a5c35ce695af8 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Fri, 9 Jun 2023 13:34:07 -0400 Subject: [PATCH] Tool types (#24032) * Tool types * Tests + fixes * Isolate types * Oops * Review comments + docs * Tests + docs * soundfile -> vision --- docs/source/en/main_classes/agent.mdx | 29 ++ src/transformers/tools/agent_types.py | 277 ++++++++++++++++++ src/transformers/tools/base.py | 12 +- tests/tools/test_agent_types.py | 121 ++++++++ .../tools/test_document_question_answering.py | 17 +- tests/tools/test_text_to_speech.py | 8 +- tests/tools/test_tools_common.py | 39 ++- tests/tools/test_translation.py | 33 +++ 8 files changed, 521 insertions(+), 15 deletions(-) create mode 100644 src/transformers/tools/agent_types.py create mode 100644 tests/tools/test_agent_types.py diff --git a/docs/source/en/main_classes/agent.mdx b/docs/source/en/main_classes/agent.mdx index 5f4b3df306..37e9f00ecb 100644 --- a/docs/source/en/main_classes/agent.mdx +++ b/docs/source/en/main_classes/agent.mdx @@ -70,3 +70,32 @@ We provide three types of agents: [`HfAgent`] uses inference endpoints for opens ### launch_gradio_demo [[autodoc]] launch_gradio_demo + +## Agent Types + +Agents can handle any type of object in-between tools; tools, being completely multimodal, can accept and return +text, image, audio, video, among other types. In order to increase compatibility between tools, as well as to +correctly render these returns in ipython (jupyter, colab, ipython notebooks, ...), we implement wrapper classes +around these types. + +The wrapped objects should continue behaving as initially; a text object should still behave as a string, an image +object should still behave as a `PIL.Image`. + +These types have three specific purposes: + +- Calling `to_raw` on the type should return the underlying object +- Calling `to_string` on the type should return the object as a string: that can be the string in case of an `AgentText` + but will be the path of the serialized version of the object in other instances +- Displaying it in an ipython kernel should display the object correctly + +### AgentText + +[[autodoc]] transformers.tools.agent_types.AgentText + +### AgentImage + +[[autodoc]] transformers.tools.agent_types.AgentImage + +### AgentAudio + +[[autodoc]] transformers.tools.agent_types.AgentAudio diff --git a/src/transformers/tools/agent_types.py b/src/transformers/tools/agent_types.py new file mode 100644 index 0000000000..f1c3261d57 --- /dev/null +++ b/src/transformers/tools/agent_types.py @@ -0,0 +1,277 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import pathlib +import tempfile +import uuid + +import numpy as np + +from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging + + +logger = logging.get_logger(__name__) + +if is_vision_available(): + import PIL.Image + from PIL import Image + from PIL.Image import Image as ImageType +else: + ImageType = object + +if is_torch_available(): + import torch + +if is_soundfile_availble(): + import soundfile as sf + + +class AgentType: + """ + Abstract class to be reimplemented to define types that can be returned by agents. + + These objects serve three purposes: + + - They behave as they were the type they're meant to be, e.g., a string for text, a PIL.Image for images + - They can be stringified: str(object) in order to return a string defining the object + - They should be displayed correctly in ipython notebooks/colab/jupyter + """ + + def __init__(self, value): + self._value = value + + def __str__(self): + return self.to_string() + + def to_raw(self): + logger.error( + "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable" + ) + return self._value + + def to_string(self) -> str: + logger.error( + "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable" + ) + return str(self._value) + + +class AgentText(AgentType, str): + """ + Text type returned by the agent. Behaves as a string. + """ + + def to_raw(self): + return self._value + + def to_string(self): + return self._value + + +class AgentImage(AgentType, ImageType): + """ + Image type returned by the agent. Behaves as a PIL.Image. + """ + + def __init__(self, value): + super().__init__(value) + + if not is_vision_available(): + raise ImportError("PIL must be installed in order to handle images.") + + self._path = None + self._raw = None + self._tensor = None + + if isinstance(value, ImageType): + self._raw = value + elif isinstance(value, (str, pathlib.Path)): + self._path = value + elif isinstance(value, torch.Tensor): + self._tensor = value + else: + raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}") + + def _ipython_display_(self, include=None, exclude=None): + """ + Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...) + """ + from IPython.display import Image, display + + display(Image(self.to_string())) + + def to_raw(self): + """ + Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image. + """ + if self._raw is not None: + return self._raw + + if self._path is not None: + self._raw = Image.open(self._path) + return self._raw + + def to_string(self): + """ + Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized + version of the image. + """ + if self._path is not None: + return self._path + + if self._raw is not None: + directory = tempfile.mkdtemp() + self._path = os.path.join(directory, str(uuid.uuid4()) + ".png") + self._raw.save(self._path) + + return self._path + + if self._tensor is not None: + array = self._tensor.cpu().detach().numpy() + + # There is likely simpler than load into image into save + img = Image.fromarray((array * 255).astype(np.uint8)) + + directory = tempfile.mkdtemp() + self._path = os.path.join(directory, str(uuid.uuid4()) + ".png") + + img.save(self._path) + + return self._path + + +class AgentAudio(AgentType): + """ + Audio type returned by the agent. + """ + + def __init__(self, value, samplerate=16_000): + super().__init__(value) + + if not is_soundfile_availble(): + raise ImportError("soundfile must be installed in order to handle audio.") + + self._path = None + self._tensor = None + + self.samplerate = samplerate + + if isinstance(value, (str, pathlib.Path)): + self._path = value + elif isinstance(value, torch.Tensor): + self._tensor = value + else: + raise ValueError(f"Unsupported audio type: {type(value)}") + + def _ipython_display_(self, include=None, exclude=None): + """ + Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...) + """ + from IPython.display import Audio, display + + display(Audio(self.to_string(), rate=self.samplerate)) + + def to_raw(self): + """ + Returns the "raw" version of that object. It is a `torch.Tensor` object. + """ + if self._tensor is not None: + return self._tensor + + if self._path is not None: + tensor, self.samplerate = sf.read(self._path) + self._tensor = torch.tensor(tensor) + return self._tensor + + def to_string(self): + """ + Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized + version of the audio. + """ + if self._path is not None: + return self._path + + if self._tensor is not None: + directory = tempfile.mkdtemp() + self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav") + sf.write(self._path, self._tensor, samplerate=self.samplerate) + return self._path + + +AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio} +INSTANCE_TYPE_MAPPING = {str: AgentText} + +if is_vision_available(): + INSTANCE_TYPE_MAPPING[PIL.Image] = AgentImage + + +def handle_agent_inputs(*args, **kwargs): + args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args] + kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()} + return args, kwargs + + +def handle_agent_outputs(outputs, output_types=None): + if isinstance(outputs, dict): + decoded_outputs = {} + for i, (k, v) in enumerate(outputs.items()): + if output_types is not None: + # If the class has defined outputs, we can map directly according to the class definition + if output_types[i] in AGENT_TYPE_MAPPING: + decoded_outputs[k] = AGENT_TYPE_MAPPING[output_types[i]](v) + else: + decoded_outputs[k] = AgentType(v) + + else: + # If the class does not have defined output, then we map according to the type + for _k, _v in INSTANCE_TYPE_MAPPING.items(): + if isinstance(v, _k): + decoded_outputs[k] = _v(v) + if k not in decoded_outputs: + decoded_outputs[k] = AgentType[v] + + elif isinstance(outputs, (list, tuple)): + decoded_outputs = type(outputs)() + for i, v in enumerate(outputs): + if output_types is not None: + # If the class has defined outputs, we can map directly according to the class definition + if output_types[i] in AGENT_TYPE_MAPPING: + decoded_outputs.append(AGENT_TYPE_MAPPING[output_types[i]](v)) + else: + decoded_outputs.append(AgentType(v)) + else: + # If the class does not have defined output, then we map according to the type + found = False + for _k, _v in INSTANCE_TYPE_MAPPING.items(): + if isinstance(v, _k): + decoded_outputs.append(_v(v)) + found = True + + if not found: + decoded_outputs.append(AgentType(v)) + + else: + if output_types[0] in AGENT_TYPE_MAPPING: + # If the class has defined outputs, we can map directly according to the class definition + decoded_outputs = AGENT_TYPE_MAPPING[output_types[0]](outputs) + + else: + # If the class does not have defined output, then we map according to the type + for _k, _v in INSTANCE_TYPE_MAPPING.items(): + if isinstance(outputs, _k): + return _v(outputs) + return AgentType(outputs) + + return decoded_outputs diff --git a/src/transformers/tools/base.py b/src/transformers/tools/base.py index add64e373d..90ddcf6aa8 100644 --- a/src/transformers/tools/base.py +++ b/src/transformers/tools/base.py @@ -37,6 +37,7 @@ from ..utils import ( is_vision_available, logging, ) +from .agent_types import handle_agent_inputs, handle_agent_outputs logger = logging.get_logger(__name__) @@ -413,6 +414,8 @@ class RemoteTool(Tool): return outputs def __call__(self, *args, **kwargs): + args, kwargs = handle_agent_inputs(*args, **kwargs) + output_image = self.tool_class is not None and self.tool_class.outputs == ["image"] inputs = self.prepare_inputs(*args, **kwargs) if isinstance(inputs, dict): @@ -421,6 +424,9 @@ class RemoteTool(Tool): outputs = self.client(inputs, output_image=output_image) if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list): outputs = outputs[0] + + outputs = handle_agent_outputs(outputs, self.tool_class.outputs if self.tool_class is not None else None) + return self.extract_outputs(outputs) @@ -550,6 +556,8 @@ class PipelineTool(Tool): return self.post_processor(outputs) def __call__(self, *args, **kwargs): + args, kwargs = handle_agent_inputs(*args, **kwargs) + if not self.is_initialized: self.setup() @@ -557,7 +565,9 @@ class PipelineTool(Tool): encoded_inputs = send_to_device(encoded_inputs, self.device) outputs = self.forward(encoded_inputs) outputs = send_to_device(outputs, "cpu") - return self.decode(outputs) + decoded_outputs = self.decode(outputs) + + return handle_agent_outputs(decoded_outputs, self.outputs) def launch_gradio_demo(tool_class: Tool): diff --git a/tests/tools/test_agent_types.py b/tests/tools/test_agent_types.py new file mode 100644 index 0000000000..a1cc4f70cc --- /dev/null +++ b/tests/tools/test_agent_types.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +import unittest +import uuid +from pathlib import Path + +from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision +from transformers.tools.agent_types import AgentAudio, AgentImage, AgentText +from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available + + +if is_torch_available(): + import torch + +if is_soundfile_availble(): + import soundfile as sf + +if is_vision_available(): + from PIL import Image + + +def get_new_path(suffix="") -> str: + directory = tempfile.mkdtemp() + return os.path.join(directory, str(uuid.uuid4()) + suffix) + + +@require_soundfile +@require_torch +class AgentAudioTests(unittest.TestCase): + def test_from_tensor(self): + tensor = torch.rand(12, dtype=torch.float64) - 0.5 + agent_type = AgentAudio(tensor) + path = str(agent_type.to_string()) + + # Ensure that the tensor and the agent_type's tensor are the same + self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4)) + + del agent_type + + # Ensure the path remains even after the object deletion + self.assertTrue(os.path.exists(path)) + + # Ensure that the file contains the same value as the original tensor + new_tensor, _ = sf.read(path) + self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4)) + + def test_from_string(self): + tensor = torch.rand(12, dtype=torch.float64) - 0.5 + path = get_new_path(suffix=".wav") + sf.write(path, tensor, 16000) + + agent_type = AgentAudio(path) + + self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4)) + self.assertEqual(agent_type.to_string(), path) + + +@require_vision +@require_torch +class AgentImageTests(unittest.TestCase): + def test_from_tensor(self): + tensor = torch.randint(0, 256, (64, 64, 3)) + agent_type = AgentImage(tensor) + path = str(agent_type.to_string()) + + # Ensure that the tensor and the agent_type's tensor are the same + self.assertTrue(torch.allclose(tensor, agent_type._tensor, atol=1e-4)) + + self.assertIsInstance(agent_type.to_raw(), Image.Image) + + # Ensure the path remains even after the object deletion + del agent_type + self.assertTrue(os.path.exists(path)) + + def test_from_string(self): + path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png" + image = Image.open(path) + agent_type = AgentImage(path) + + self.assertTrue(path.samefile(agent_type.to_string())) + self.assertTrue(image == agent_type.to_raw()) + + # Ensure the path remains even after the object deletion + del agent_type + self.assertTrue(os.path.exists(path)) + + def test_from_image(self): + path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png" + image = Image.open(path) + agent_type = AgentImage(image) + + self.assertFalse(path.samefile(agent_type.to_string())) + self.assertTrue(image == agent_type.to_raw()) + + # Ensure the path remains even after the object deletion + del agent_type + self.assertTrue(os.path.exists(path)) + + +class AgentTextTests(unittest.TestCase): + def test_from_string(self): + string = "Hey!" + agent_type = AgentText(string) + + self.assertEqual(string, agent_type.to_string()) + self.assertEqual(string, agent_type.to_raw()) + self.assertEqual(string, agent_type) diff --git a/tests/tools/test_document_question_answering.py b/tests/tools/test_document_question_answering.py index a799676e9a..1d77bcb470 100644 --- a/tests/tools/test_document_question_answering.py +++ b/tests/tools/test_document_question_answering.py @@ -30,28 +30,27 @@ class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin): def test_exact_match_arg(self): dataset = load_dataset("hf-internal-testing/example-documents", split="test") - image = dataset[0]["image"] + document = dataset[0]["image"] - result = self.tool(image, "When is the coffee break?") + result = self.tool(document, "When is the coffee break?") self.assertEqual(result, "11-14 to 11:39 a.m.") def test_exact_match_arg_remote(self): dataset = load_dataset("hf-internal-testing/example-documents", split="test") - image = dataset[0]["image"] + document = dataset[0]["image"] - result = self.remote_tool(image, "When is the coffee break?") + result = self.remote_tool(document, "When is the coffee break?") self.assertEqual(result, "11-14 to 11:39 a.m.") def test_exact_match_kwarg(self): dataset = load_dataset("hf-internal-testing/example-documents", split="test") - image = dataset[0]["image"] + document = dataset[0]["image"] - result = self.tool(image=image, question="When is the coffee break?") - self.assertEqual(result, "11-14 to 11:39 a.m.") + self.tool(document=document, question="When is the coffee break?") def test_exact_match_kwarg_remote(self): dataset = load_dataset("hf-internal-testing/example-documents", split="test") - image = dataset[0]["image"] + document = dataset[0]["image"] - result = self.remote_tool(image=image, question="When is the coffee break?") + result = self.remote_tool(document=document, question="When is the coffee break?") self.assertEqual(result, "11-14 to 11:39 a.m.") diff --git a/tests/tools/test_text_to_speech.py b/tests/tools/test_text_to_speech.py index d404f35740..a63017d277 100644 --- a/tests/tools/test_text_to_speech.py +++ b/tests/tools/test_text_to_speech.py @@ -37,9 +37,11 @@ class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin): # SpeechT5 isn't deterministic torch.manual_seed(0) result = self.tool("hey") + resulting_tensor = result.to_raw() self.assertTrue( torch.allclose( - result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293]) + resulting_tensor[:3], + torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]), ) ) @@ -47,8 +49,10 @@ class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin): # SpeechT5 isn't deterministic torch.manual_seed(0) result = self.tool("hey") + resulting_tensor = result.to_raw() self.assertTrue( torch.allclose( - result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293]) + resulting_tensor[:3], + torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]), ) ) diff --git a/tests/tools/test_tools_common.py b/tests/tools/test_tools_common.py index 5e66885cac..984edfcd8c 100644 --- a/tests/tools/test_tools_common.py +++ b/tests/tools/test_tools_common.py @@ -18,6 +18,7 @@ from typing import List from transformers import is_torch_available, is_vision_available from transformers.testing_utils import get_tests_dir, is_tool_test +from transformers.tools.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText if is_torch_available(): @@ -54,11 +55,11 @@ def output_types(outputs: List): output_types = [] for output in outputs: - if isinstance(output, str): + if isinstance(output, (str, AgentText)): output_types.append("text") - elif isinstance(output, Image.Image): + elif isinstance(output, (Image.Image, AgentImage)): output_types.append("image") - elif isinstance(output, torch.Tensor): + elif isinstance(output, (torch.Tensor, AgentAudio)): output_types.append("audio") else: raise ValueError(f"Invalid output: {output}") @@ -98,3 +99,35 @@ class ToolTesterMixin: self.assertTrue(hasattr(self.tool, "description")) self.assertTrue(hasattr(self.tool, "default_checkpoint")) self.assertTrue(self.tool.description.startswith("This is a tool that")) + + def test_agent_types_outputs(self): + inputs = create_inputs(self.tool.inputs) + outputs = self.tool(*inputs) + + if not isinstance(outputs, list): + outputs = [outputs] + + self.assertEqual(len(outputs), len(self.tool.outputs)) + + for output, output_type in zip(outputs, self.tool.outputs): + agent_type = AGENT_TYPE_MAPPING[output_type] + self.assertTrue(isinstance(output, agent_type)) + + def test_agent_types_inputs(self): + inputs = create_inputs(self.tool.inputs) + + _inputs = [] + + for _input, input_type in zip(inputs, self.tool.inputs): + if isinstance(input_type, list): + _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type]) + else: + _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) + + # Should not raise an error + outputs = self.tool(*inputs) + + if not isinstance(outputs, list): + outputs = [outputs] + + self.assertEqual(len(outputs), len(self.tool.outputs)) diff --git a/tests/tools/test_translation.py b/tests/tools/test_translation.py index 2ccb043b00..15e1c8cd6a 100644 --- a/tests/tools/test_translation.py +++ b/tests/tools/test_translation.py @@ -16,6 +16,7 @@ import unittest from transformers import load_tool +from transformers.tools.agent_types import AGENT_TYPE_MAPPING from .test_tools_common import ToolTesterMixin, output_types @@ -51,3 +52,35 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin): outputs = [outputs] self.assertListEqual(output_types(outputs), self.tool.outputs) + + def test_agent_types_outputs(self): + inputs = ["Hey, what's up?", "English", "Spanish"] + outputs = self.tool(*inputs) + + if not isinstance(outputs, list): + outputs = [outputs] + + self.assertEqual(len(outputs), len(self.tool.outputs)) + + for output, output_type in zip(outputs, self.tool.outputs): + agent_type = AGENT_TYPE_MAPPING[output_type] + self.assertTrue(isinstance(output, agent_type)) + + def test_agent_types_inputs(self): + inputs = ["Hey, what's up?", "English", "Spanish"] + + _inputs = [] + + for _input, input_type in zip(inputs, self.tool.inputs): + if isinstance(input_type, list): + _inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type]) + else: + _inputs.append(AGENT_TYPE_MAPPING[input_type](_input)) + + # Should not raise an error + outputs = self.tool(*inputs) + + if not isinstance(outputs, list): + outputs = [outputs] + + self.assertEqual(len(outputs), len(self.tool.outputs))