Test composition (#23214)
* Remove nestedness in tool config * Really do it * Use remote tools descriptions * Work * Clean up eval * Changes * Tools * Tools * tool * Fix everything * Use last result/assign for evaluation * Prompt * Remove hardcoded selection * Evaluation for chat agents * correct some spelling * Small fixes * Change summarization model (#23172) * Fix link displayed * Update description of the tool * Fixes in chat prompt * Custom tools, custom prompt * Tool clean up * save_pretrained and push_to_hub for tool * Fix init * Tests * Fix tests * Tool save/from_hub/push_to_hub and tool->load_tool * Clean push_to_hub and add app file * Custom inference API for endpoints too * Clean up * old remote tool and new remote tool * Make a requirements * return_code adds tool creation * Avoid redundancy between global variables * Remote tools can be loaded * Tests * Text summarization tests * Quality * Properly mark tests * Test the python interpreter * And the CI shall be green. * fix loading of additional tools * Work on RemoteTool and fix tests * General clean up * Guard imports * Fix tools * docs: Fix broken link in 'How to add a model...' (#23216) fix link * Get default endpoint from the Hub * Add guide * Simplify tool config * Docs * Some fixes * Docs * Docs * Docs * Fix code returned by agent * Try this * Match args with signature in remote tool * Should fix python interpreter for Python 3.8 * Fix push_to_hub for tools * Other fixes to push_to_hub * Add API doc page * Docs * Docs * Custom tools * Pin tensorflow-probability (#23220) * Pin tensorflow-probability * [all-test] * [all-test] Fix syntax for bash * PoC for some chaining API * Text to speech * J'ai pris des libertés * Rename * Basic python interpreter * Add agents * Quality * Add translation tool * temp * GenQA + LID + S2T * Quality + word missing in translation * Add open assistance, support f-strings in evaluate * captioning + s2t fixes * Style * Refactor descriptions and remove chain * Support errors and rename OpenAssistantAgent * Add setup * Deal with typos + example of inference API * Some rename + README * Fixes * Update prompt * Unwanted change * Make sure everyone has a default * One prompt to rule them all. * SD * Description * Clean up remote tools * More remote tools * Add option to return code and update doc * Image segmentation * ControlNet * Gradio demo * Diffusers protection * Lib protection * ControlNet description * Cleanup * Style * Remove accelerate and try to be reproducible * No randomness * Male Basic optional in token * Clean description * Better prompts * Fix args eval in interpreter * Add tool wrapper * Tool on the Hub * Style post-rebase * Big refactor of descriptions, batch generation and evaluation for agents * Make problems easier - interface to debug * More problems, add python primitives * Back to one prompt * Remove dict for translation * Be consistent * Add prompts * New version of the agent * Evaluate new agents * New endpoints agents * Make all tools a dict variable * Typo * Add problems * Add to big prompt * Harmonize * Add tools * New evaluation * Add more tools * Build prompt with tools descriptions * Tools on the Hub * Let's chat! * Cleanup * Temporary bs4 safeguard * Cache agents and clean up * Blank init * Fix evaluation for agents * New format for tools on the Hub * Add method to reset state * Remove nestedness in tool config * Really do it * Use remote tools descriptions * Work * Clean up eval * Changes * Tools * Tools * tool * Fix everything * Use last result/assign for evaluation * Prompt * Remove hardcoded selection * Evaluation for chat agents * correct some spelling * Small fixes * Change summarization model (#23172) * Fix link displayed * Update description of the tool * Fixes in chat prompt * Custom tools, custom prompt * Tool clean up * save_pretrained and push_to_hub for tool * Fix init * Tests * Fix tests * Tool save/from_hub/push_to_hub and tool->load_tool * Clean push_to_hub and add app file * Custom inference API for endpoints too * Clean up * old remote tool and new remote tool * Make a requirements * return_code adds tool creation * Avoid redundancy between global variables * Remote tools can be loaded * Tests * Text summarization tests * Quality * Properly mark tests * Test the python interpreter * And the CI shall be green. * Work on RemoteTool and fix tests * fix loading of additional tools * General clean up * Guard imports * Fix tools * Get default endpoint from the Hub * Simplify tool config * Add guide * Docs * Some fixes * Docs * Docs * Fix code returned by agent * Try this * Docs * Match args with signature in remote tool * Should fix python interpreter for Python 3.8 * Fix push_to_hub for tools * Other fixes to push_to_hub * Add API doc page * Fixes * Doc fixes * Docs * Fix audio * Custom tools * Audio fix * Improve custom tools docstring * Docstrings * Trigger CI * Mode docstrings * More docstrings * Improve custom tools * Fix for remote tools * Style * Fix repo consistency * Quality * Tip * Cleanup on doc * Cleanup toc * Add disclaimer for starcoder vs openai * Remove disclaimer * Small fixed in the prompts * 4.29 * Update src/transformers/tools/agents.py Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> * Complete documentation * Small fixes * Agent evaluation * Note about gradio-tools & LC * Clean up agents and prompt * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Note about gradio-tools & LC * Add copyrights and address review comments * Quality * Add all language codes * Add remote tool tests * Move custom prompts to other docs * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * TTS tests * Quality --------- Co-authored-by: Lysandre <hi@lyand.re> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com> Co-authored-by: Connor Henderson <connor.henderson@talkiatry.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr> Co-authored-by: Lysandre <lysandre@huggingface.co> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
Sylvain Gugger
parent
d5e1c98120
commit
2a2be57697
0
tests/tools/__init__.py
Normal file
0
tests/tools/__init__.py
Normal file
57
tests/tools/test_document_question_answering.py
Normal file
57
tests/tools/test_document_question_answering.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("document-question-answering")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("document-question-answering", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[0]["image"]
|
||||
|
||||
result = self.tool(image, "When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[0]["image"]
|
||||
|
||||
result = self.remote_tool(image, "When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[0]["image"]
|
||||
|
||||
result = self.tool(image=image, question="When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[0]["image"]
|
||||
|
||||
result = self.remote_tool(image=image, question="When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
53
tests/tools/test_image_captioning.py
Normal file
53
tests/tools/test_image_captioning.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_vision_available, load_tool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-captioning")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("image-captioning", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image=image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
53
tests/tools/test_image_question_answering.py
Normal file
53
tests/tools/test_image_question_answering.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_vision_available, load_tool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-question-answering")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("image-question-answering", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image, "How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image, "How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image=image, question="How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
53
tests/tools/test_image_segmentation.py
Normal file
53
tests/tools/test_image_segmentation.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_vision_available, load_tool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-segmentation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("image-segmentation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
result = self.tool(image, "cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
result = self.remote_tool(image, "cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
result = self.tool(image=image, prompt="cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
result = self.remote_tool(image=image, prompt="cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
124
tests/tools/test_python_interpreter.py
Normal file
124
tests/tools/test_python_interpreter.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import CaptureStdout
|
||||
from transformers.tools.python_interpreter import evaluate
|
||||
|
||||
|
||||
# Fake function we will use as tool
|
||||
def add_two(x):
|
||||
return x + 2
|
||||
|
||||
|
||||
class PythonInterpreterTester(unittest.TestCase):
|
||||
def test_evaluate_assign(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3})
|
||||
|
||||
code = "x = y"
|
||||
state = {"y": 5}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 5, "y": 5})
|
||||
|
||||
def test_evaluate_call(self):
|
||||
code = "y = add_two(x)"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5})
|
||||
|
||||
# Won't work without the tool
|
||||
with CaptureStdout() as out:
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result is None
|
||||
assert "tried to execute add_two" in out.out
|
||||
|
||||
def test_evaluate_constant(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3})
|
||||
|
||||
def test_evaluate_dict(self):
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
|
||||
|
||||
def test_evaluate_expression(self):
|
||||
code = "x = 3\ny = 5"
|
||||
state = {}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5})
|
||||
|
||||
def test_evaluate_f_string(self):
|
||||
code = "text = f'This is x: {x}.'"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == "This is x: 3."
|
||||
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3."})
|
||||
|
||||
def test_evaluate_if(self):
|
||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 3, "y": 2})
|
||||
|
||||
state = {"x": 8}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 8, "y": 5})
|
||||
|
||||
def test_evaluate_list(self):
|
||||
code = "test_list = [x, add_two(x)]"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
self.assertListEqual(result, [3, 5])
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
|
||||
|
||||
def test_evaluate_name(self):
|
||||
code = "y = x"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "y": 3})
|
||||
|
||||
def test_evaluate_subscript(self):
|
||||
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
|
||||
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
|
||||
38
tests/tools/test_speech_to_text.py
Normal file
38
tests/tools/test_speech_to_text.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available, load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("speech-to-text")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(torch.ones(3000))
|
||||
self.assertEqual(result, " you")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(audio=torch.ones(3000))
|
||||
self.assertEqual(result, " you")
|
||||
43
tests/tools/test_text_classification.py
Normal file
43
tests/tools/test_text_classification.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class TextClassificationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text-classification")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("text-classification", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("That's quite cool", ["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool("That's quite cool", ["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="That's quite cool", labels=["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text="That's quite cool", labels=["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
52
tests/tools/test_text_question_answering.py
Normal file
52
tests/tools/test_text_question_answering.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
TEXT = """
|
||||
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
|
||||
|
||||
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
|
||||
|
||||
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
|
||||
"""
|
||||
|
||||
|
||||
class TextQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text-question-answering")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("text-question-answering", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(TEXT, "What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool(TEXT, "What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text=TEXT, question="What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text=TEXT, question="What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
64
tests/tools/test_text_summarization.py
Normal file
64
tests/tools/test_text_summarization.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
TEXT = """
|
||||
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
|
||||
|
||||
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
|
||||
|
||||
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
|
||||
"""
|
||||
|
||||
|
||||
class TextSummarizationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("summarization")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("summarization", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool(TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text=TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text=TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
54
tests/tools/test_text_to_speech.py
Normal file
54
tests/tools/test_text_to_speech.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
@require_torch
|
||||
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text-to-speech")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
# SpeechT5 isn't deterministic
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
|
||||
)
|
||||
)
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
# SpeechT5 isn't deterministic
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
|
||||
)
|
||||
)
|
||||
100
tests/tools/test_tools_common.py
Normal file
100
tests/tools/test_tools_common.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import get_tests_dir, is_tool_test
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
authorized_types = ["text", "image", "audio"]
|
||||
|
||||
|
||||
def create_inputs(input_types: List[str]):
|
||||
inputs = []
|
||||
|
||||
for input_type in input_types:
|
||||
if input_type == "text":
|
||||
inputs.append("Text input")
|
||||
elif input_type == "image":
|
||||
inputs.append(
|
||||
Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
)
|
||||
elif input_type == "audio":
|
||||
inputs.append(torch.ones(3000))
|
||||
elif isinstance(input_type, list):
|
||||
inputs.append(create_inputs(input_type))
|
||||
else:
|
||||
raise ValueError(f"Invalid type requested: {input_type}")
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def output_types(outputs: List):
|
||||
output_types = []
|
||||
|
||||
for output in outputs:
|
||||
if isinstance(output, str):
|
||||
output_types.append("text")
|
||||
elif isinstance(output, Image.Image):
|
||||
output_types.append("image")
|
||||
elif isinstance(output, torch.Tensor):
|
||||
output_types.append("audio")
|
||||
else:
|
||||
raise ValueError(f"Invalid output: {output}")
|
||||
|
||||
return output_types
|
||||
|
||||
|
||||
@is_tool_test
|
||||
class ToolTesterMixin:
|
||||
def test_inputs_outputs(self):
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "outputs"))
|
||||
|
||||
inputs = self.tool.inputs
|
||||
for _input in inputs:
|
||||
if isinstance(_input, list):
|
||||
for __input in _input:
|
||||
self.assertTrue(__input in authorized_types)
|
||||
else:
|
||||
self.assertTrue(_input in authorized_types)
|
||||
|
||||
outputs = self.tool.outputs
|
||||
for _output in outputs:
|
||||
self.assertTrue(_output in authorized_types)
|
||||
|
||||
def test_call(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
# There is a single output
|
||||
if len(self.tool.outputs) == 1:
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertListEqual(output_types(outputs), self.tool.outputs)
|
||||
|
||||
def test_common_attributes(self):
|
||||
self.assertTrue(hasattr(self.tool, "description"))
|
||||
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
|
||||
self.assertTrue(self.tool.description.startswith("This is a tool that"))
|
||||
53
tests/tools/test_translation.py
Normal file
53
tests/tools/test_translation.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin, output_types
|
||||
|
||||
|
||||
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("translation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("translation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_call(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
# There is a single output
|
||||
if len(self.tool.outputs) == 1:
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertListEqual(output_types(outputs), self.tool.outputs)
|
||||
Reference in New Issue
Block a user