Fix the behaviour of DefaultArgumentHandler (removing it). (#8180)

* Some work to fix the behaviour of DefaultArgumentHandler by removing it.

* Fixing specific pipelines argument checking.
This commit is contained in:
Nicolas Patry
2020-11-02 12:33:50 +01:00
committed by GitHub
parent 00cc2d1df2
commit 84caa23301
4 changed files with 129 additions and 148 deletions

View File

@@ -1,8 +1,9 @@
import unittest
from typing import List, Optional
from transformers import is_tf_available, is_torch_available, pipeline
from transformers.pipelines import DefaultArgumentHandler, Pipeline
# from transformers.pipelines import DefaultArgumentHandler, Pipeline
from transformers.pipelines import Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
@@ -200,74 +201,74 @@ class MonoInputPipelineCommonMixin:
self.assertRaises(Exception, nlp, self.invalid_inputs)
@is_pipeline_test
class DefaultArgumentHandlerTestCase(unittest.TestCase):
def setUp(self) -> None:
self.handler = DefaultArgumentHandler()
def test_kwargs_x(self):
mono_data = {"X": "This is a sample input"}
mono_args = self.handler(**mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
multi_args = self.handler(**multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
def test_kwargs_data(self):
mono_data = {"data": "This is a sample input"}
mono_args = self.handler(**mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
multi_args = self.handler(**multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
def test_multi_kwargs(self):
mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
mono_args = self.handler(**mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 2)
multi_data = {
"data": ["This is a sample input", "This is a second sample input"],
"test": ["This is a sample input 2", "This is a second sample input 2"],
}
multi_args = self.handler(**multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 4)
def test_args(self):
mono_data = "This is a sample input"
mono_args = self.handler(mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
mono_data = ["This is a sample input"]
mono_args = self.handler(mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
multi_data = ["This is a sample input", "This is a second sample input"]
multi_args = self.handler(multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
multi_data = ["This is a sample input", "This is a second sample input"]
multi_args = self.handler(*multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
# @is_pipeline_test
# class DefaultArgumentHandlerTestCase(unittest.TestCase):
# def setUp(self) -> None:
# self.handler = DefaultArgumentHandler()
#
# def test_kwargs_x(self):
# mono_data = {"X": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_kwargs_data(self):
# mono_data = {"data": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_multi_kwargs(self):
# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 2)
#
# multi_data = {
# "data": ["This is a sample input", "This is a second sample input"],
# "test": ["This is a sample input 2", "This is a second sample input 2"],
# }
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 4)
#
# def test_args(self):
# mono_data = "This is a sample input"
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# mono_data = ["This is a sample input"]
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(*multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)