Fix the behaviour of DefaultArgumentHandler (removing it). (#8180)
* Some work to fix the behaviour of DefaultArgumentHandler by removing it. * Fixing specific pipelines argument checking.
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import unittest
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import is_tf_available, is_torch_available, pipeline
|
||||
from transformers.pipelines import DefaultArgumentHandler, Pipeline
|
||||
|
||||
# from transformers.pipelines import DefaultArgumentHandler, Pipeline
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
|
||||
|
||||
|
||||
@@ -200,74 +201,74 @@ class MonoInputPipelineCommonMixin:
|
||||
self.assertRaises(Exception, nlp, self.invalid_inputs)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class DefaultArgumentHandlerTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.handler = DefaultArgumentHandler()
|
||||
|
||||
def test_kwargs_x(self):
|
||||
mono_data = {"X": "This is a sample input"}
|
||||
mono_args = self.handler(**mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 1)
|
||||
|
||||
multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
|
||||
multi_args = self.handler(**multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 2)
|
||||
|
||||
def test_kwargs_data(self):
|
||||
mono_data = {"data": "This is a sample input"}
|
||||
mono_args = self.handler(**mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 1)
|
||||
|
||||
multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
|
||||
multi_args = self.handler(**multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 2)
|
||||
|
||||
def test_multi_kwargs(self):
|
||||
mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
|
||||
mono_args = self.handler(**mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 2)
|
||||
|
||||
multi_data = {
|
||||
"data": ["This is a sample input", "This is a second sample input"],
|
||||
"test": ["This is a sample input 2", "This is a second sample input 2"],
|
||||
}
|
||||
multi_args = self.handler(**multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 4)
|
||||
|
||||
def test_args(self):
|
||||
mono_data = "This is a sample input"
|
||||
mono_args = self.handler(mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 1)
|
||||
|
||||
mono_data = ["This is a sample input"]
|
||||
mono_args = self.handler(mono_data)
|
||||
|
||||
self.assertTrue(isinstance(mono_args, list))
|
||||
self.assertEqual(len(mono_args), 1)
|
||||
|
||||
multi_data = ["This is a sample input", "This is a second sample input"]
|
||||
multi_args = self.handler(multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 2)
|
||||
|
||||
multi_data = ["This is a sample input", "This is a second sample input"]
|
||||
multi_args = self.handler(*multi_data)
|
||||
|
||||
self.assertTrue(isinstance(multi_args, list))
|
||||
self.assertEqual(len(multi_args), 2)
|
||||
# @is_pipeline_test
|
||||
# class DefaultArgumentHandlerTestCase(unittest.TestCase):
|
||||
# def setUp(self) -> None:
|
||||
# self.handler = DefaultArgumentHandler()
|
||||
#
|
||||
# def test_kwargs_x(self):
|
||||
# mono_data = {"X": "This is a sample input"}
|
||||
# mono_args = self.handler(**mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 1)
|
||||
#
|
||||
# multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
|
||||
# multi_args = self.handler(**multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 2)
|
||||
#
|
||||
# def test_kwargs_data(self):
|
||||
# mono_data = {"data": "This is a sample input"}
|
||||
# mono_args = self.handler(**mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 1)
|
||||
#
|
||||
# multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
|
||||
# multi_args = self.handler(**multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 2)
|
||||
#
|
||||
# def test_multi_kwargs(self):
|
||||
# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
|
||||
# mono_args = self.handler(**mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 2)
|
||||
#
|
||||
# multi_data = {
|
||||
# "data": ["This is a sample input", "This is a second sample input"],
|
||||
# "test": ["This is a sample input 2", "This is a second sample input 2"],
|
||||
# }
|
||||
# multi_args = self.handler(**multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 4)
|
||||
#
|
||||
# def test_args(self):
|
||||
# mono_data = "This is a sample input"
|
||||
# mono_args = self.handler(mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 1)
|
||||
#
|
||||
# mono_data = ["This is a sample input"]
|
||||
# mono_args = self.handler(mono_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(mono_args, list))
|
||||
# self.assertEqual(len(mono_args), 1)
|
||||
#
|
||||
# multi_data = ["This is a sample input", "This is a second sample input"]
|
||||
# multi_args = self.handler(multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 2)
|
||||
#
|
||||
# multi_data = ["This is a sample input", "This is a second sample input"]
|
||||
# multi_args = self.handler(*multi_data)
|
||||
#
|
||||
# self.assertTrue(isinstance(multi_args, list))
|
||||
# self.assertEqual(len(multi_args), 2)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.testing_utils import require_tf, require_torch, slow
|
||||
|
||||
@@ -37,7 +39,7 @@ EXPECTED_FILL_MASK_TARGET_RESULT = [
|
||||
|
||||
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "fill-mask"
|
||||
pipeline_loading_kwargs = {"topk": 2}
|
||||
pipeline_loading_kwargs = {"top_k": 2}
|
||||
small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator
|
||||
large_models = ["distilroberta-base"] # Models tested with the @slow decorator
|
||||
mandatory_keys = {"sequence", "score", "token"}
|
||||
@@ -51,6 +53,28 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
]
|
||||
expected_check_keys = ["sequence"]
|
||||
|
||||
@require_torch
|
||||
def test_torch_topk_deprecation(self):
|
||||
# At pipeline initialization only it was not enabled at pipeline
|
||||
# call site before
|
||||
with pytest.warns(FutureWarning, match=r".*use `top_k`.*"):
|
||||
pipeline(task="fill-mask", model=self.small_models[0], topk=1)
|
||||
|
||||
@require_torch
|
||||
def test_torch_fill_mask(self):
|
||||
valid_inputs = "My name is <mask>"
|
||||
nlp = pipeline(task="fill-mask", model=self.small_models[0])
|
||||
outputs = nlp(valid_inputs)
|
||||
self.assertIsInstance(outputs, list)
|
||||
|
||||
# This passes
|
||||
outputs = nlp(valid_inputs, targets=[" Patrick", " Clara"])
|
||||
self.assertIsInstance(outputs, list)
|
||||
|
||||
# This used to fail with `cannot mix args and kwargs`
|
||||
outputs = nlp(valid_inputs, something=False)
|
||||
self.assertIsInstance(outputs, list)
|
||||
|
||||
@require_torch
|
||||
def test_torch_fill_mask_with_targets(self):
|
||||
valid_inputs = ["My name is <mask>"]
|
||||
@@ -94,7 +118,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
topk=2,
|
||||
top_k=2,
|
||||
)
|
||||
|
||||
mono_result = nlp(valid_inputs[0], targets=valid_targets)
|
||||
|
||||
@@ -17,7 +17,7 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
|
||||
sum = 0.0
|
||||
for score in result["scores"]:
|
||||
sum += score
|
||||
self.assertAlmostEqual(sum, 1.0)
|
||||
self.assertAlmostEqual(sum, 1.0, places=5)
|
||||
|
||||
def _test_entailment_id(self, nlp: Pipeline):
|
||||
config = nlp.model.config
|
||||
|
||||
Reference in New Issue
Block a user