Fix the behaviour of DefaultArgumentHandler (removing it). (#8180)
* Some work to fix the behaviour of DefaultArgumentHandler by removing it. * Fixing specific pipelines argument checking.
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import pipeline
|
||||
from transformers.testing_utils import require_tf, require_torch, slow
|
||||
|
||||
@@ -37,7 +39,7 @@ EXPECTED_FILL_MASK_TARGET_RESULT = [
|
||||
|
||||
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "fill-mask"
|
||||
pipeline_loading_kwargs = {"topk": 2}
|
||||
pipeline_loading_kwargs = {"top_k": 2}
|
||||
small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator
|
||||
large_models = ["distilroberta-base"] # Models tested with the @slow decorator
|
||||
mandatory_keys = {"sequence", "score", "token"}
|
||||
@@ -51,6 +53,28 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
]
|
||||
expected_check_keys = ["sequence"]
|
||||
|
||||
@require_torch
|
||||
def test_torch_topk_deprecation(self):
|
||||
# At pipeline initialization only it was not enabled at pipeline
|
||||
# call site before
|
||||
with pytest.warns(FutureWarning, match=r".*use `top_k`.*"):
|
||||
pipeline(task="fill-mask", model=self.small_models[0], topk=1)
|
||||
|
||||
@require_torch
|
||||
def test_torch_fill_mask(self):
|
||||
valid_inputs = "My name is <mask>"
|
||||
nlp = pipeline(task="fill-mask", model=self.small_models[0])
|
||||
outputs = nlp(valid_inputs)
|
||||
self.assertIsInstance(outputs, list)
|
||||
|
||||
# This passes
|
||||
outputs = nlp(valid_inputs, targets=[" Patrick", " Clara"])
|
||||
self.assertIsInstance(outputs, list)
|
||||
|
||||
# This used to fail with `cannot mix args and kwargs`
|
||||
outputs = nlp(valid_inputs, something=False)
|
||||
self.assertIsInstance(outputs, list)
|
||||
|
||||
@require_torch
|
||||
def test_torch_fill_mask_with_targets(self):
|
||||
valid_inputs = ["My name is <mask>"]
|
||||
@@ -94,7 +118,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
topk=2,
|
||||
top_k=2,
|
||||
)
|
||||
|
||||
mono_result = nlp(valid_inputs[0], targets=valid_targets)
|
||||
|
||||
Reference in New Issue
Block a user