FillMaskPipeline: support passing top_k on __call__ (#7971)
* FillMaskPipeline: support passing top_k on __call__ Also move from topk to top_k * migrate to new param name in tests * Review from @sgugger
This commit is contained in:
@@ -226,7 +226,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
topk=2,
|
||||
top_k=2,
|
||||
)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
||||
@@ -249,7 +249,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
topk=2,
|
||||
top_k=2,
|
||||
)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
||||
@@ -298,7 +298,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
topk=2,
|
||||
top_k=2,
|
||||
)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp,
|
||||
@@ -326,7 +326,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
]
|
||||
valid_targets = [" Patrick", " Clara"]
|
||||
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2)
|
||||
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp,
|
||||
valid_inputs,
|
||||
|
||||
Reference in New Issue
Block a user