Fix the behaviour of DefaultArgumentHandler (removing it). (#8180)

* Some work to fix the behaviour of DefaultArgumentHandler by removing it.

* Fixing specific pipelines argument checking.
This commit is contained in:
Nicolas Patry
2020-11-02 12:33:50 +01:00
committed by GitHub
parent 00cc2d1df2
commit 84caa23301
4 changed files with 129 additions and 148 deletions

View File

@@ -1,8 +1,9 @@
import unittest
from typing import List, Optional
from transformers import is_tf_available, is_torch_available, pipeline
from transformers.pipelines import DefaultArgumentHandler, Pipeline
# from transformers.pipelines import DefaultArgumentHandler, Pipeline
from transformers.pipelines import Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
@@ -200,74 +201,74 @@ class MonoInputPipelineCommonMixin:
self.assertRaises(Exception, nlp, self.invalid_inputs)
@is_pipeline_test
class DefaultArgumentHandlerTestCase(unittest.TestCase):
def setUp(self) -> None:
self.handler = DefaultArgumentHandler()
def test_kwargs_x(self):
mono_data = {"X": "This is a sample input"}
mono_args = self.handler(**mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
multi_args = self.handler(**multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
def test_kwargs_data(self):
mono_data = {"data": "This is a sample input"}
mono_args = self.handler(**mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
multi_args = self.handler(**multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
def test_multi_kwargs(self):
mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
mono_args = self.handler(**mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 2)
multi_data = {
"data": ["This is a sample input", "This is a second sample input"],
"test": ["This is a sample input 2", "This is a second sample input 2"],
}
multi_args = self.handler(**multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 4)
def test_args(self):
mono_data = "This is a sample input"
mono_args = self.handler(mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
mono_data = ["This is a sample input"]
mono_args = self.handler(mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
multi_data = ["This is a sample input", "This is a second sample input"]
multi_args = self.handler(multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
multi_data = ["This is a sample input", "This is a second sample input"]
multi_args = self.handler(*multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
# @is_pipeline_test
# class DefaultArgumentHandlerTestCase(unittest.TestCase):
# def setUp(self) -> None:
# self.handler = DefaultArgumentHandler()
#
# def test_kwargs_x(self):
# mono_data = {"X": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_kwargs_data(self):
# mono_data = {"data": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_multi_kwargs(self):
# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 2)
#
# multi_data = {
# "data": ["This is a sample input", "This is a second sample input"],
# "test": ["This is a sample input 2", "This is a second sample input 2"],
# }
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 4)
#
# def test_args(self):
# mono_data = "This is a sample input"
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# mono_data = ["This is a sample input"]
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(*multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)

View File

@@ -1,5 +1,7 @@
import unittest
import pytest
from transformers import pipeline
from transformers.testing_utils import require_tf, require_torch, slow
@@ -37,7 +39,7 @@ EXPECTED_FILL_MASK_TARGET_RESULT = [
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "fill-mask"
pipeline_loading_kwargs = {"topk": 2}
pipeline_loading_kwargs = {"top_k": 2}
small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator
large_models = ["distilroberta-base"] # Models tested with the @slow decorator
mandatory_keys = {"sequence", "score", "token"}
@@ -51,6 +53,28 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
]
expected_check_keys = ["sequence"]
@require_torch
def test_torch_topk_deprecation(self):
# At pipeline initialization only it was not enabled at pipeline
# call site before
with pytest.warns(FutureWarning, match=r".*use `top_k`.*"):
pipeline(task="fill-mask", model=self.small_models[0], topk=1)
@require_torch
def test_torch_fill_mask(self):
valid_inputs = "My name is <mask>"
nlp = pipeline(task="fill-mask", model=self.small_models[0])
outputs = nlp(valid_inputs)
self.assertIsInstance(outputs, list)
# This passes
outputs = nlp(valid_inputs, targets=[" Patrick", " Clara"])
self.assertIsInstance(outputs, list)
# This used to fail with `cannot mix args and kwargs`
outputs = nlp(valid_inputs, something=False)
self.assertIsInstance(outputs, list)
@require_torch
def test_torch_fill_mask_with_targets(self):
valid_inputs = ["My name is <mask>"]
@@ -94,7 +118,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
model=model_name,
tokenizer=model_name,
framework="pt",
topk=2,
top_k=2,
)
mono_result = nlp(valid_inputs[0], targets=valid_targets)

View File

@@ -17,7 +17,7 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
sum = 0.0
for score in result["scores"]:
sum += score
self.assertAlmostEqual(sum, 1.0)
self.assertAlmostEqual(sum, 1.0, places=5)
def _test_entailment_id(self, nlp: Pipeline):
config = nlp.model.config