Temporarily revert the fill-mask improvements.
This commit is contained in:
@@ -78,8 +78,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
def test_torch_fill_mask_with_targets(self):
|
||||
valid_inputs = ["My name is <mask>"]
|
||||
# ' Sam' will yield a warning but work
|
||||
valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
|
||||
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
|
||||
invalid_targets = [[], [""], ""]
|
||||
for model_name in self.small_models:
|
||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
|
||||
@@ -90,34 +89,10 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
for targets in invalid_targets:
|
||||
self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)
|
||||
|
||||
@require_torch
|
||||
def test_torch_fill_mask_with_targets_and_topk(self):
|
||||
model_name = self.small_models[0]
|
||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
|
||||
targets = [" Teven", "ĠPatrick", "ĠClara"]
|
||||
top_k = 2
|
||||
outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
|
||||
|
||||
self.assertEqual(len(outputs), 2)
|
||||
|
||||
@require_torch
|
||||
def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
|
||||
model_name = self.small_models[0]
|
||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
|
||||
# String duplicates + id duplicates
|
||||
targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
|
||||
top_k = 10
|
||||
outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
|
||||
|
||||
# The target list contains duplicates, so we can't output more
|
||||
# than them
|
||||
self.assertEqual(len(outputs), 3)
|
||||
|
||||
@require_tf
|
||||
def test_tf_fill_mask_with_targets(self):
|
||||
valid_inputs = ["My name is <mask>"]
|
||||
# ' Sam' will yield a warning but work
|
||||
valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
|
||||
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
|
||||
invalid_targets = [[], [""], ""]
|
||||
for model_name in self.small_models:
|
||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
|
||||
@@ -136,7 +111,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
"My name is <mask>",
|
||||
"The largest city in France is <mask>",
|
||||
]
|
||||
valid_targets = ["ĠPatrick", "ĠClara"]
|
||||
valid_targets = [" Patrick", " Clara"]
|
||||
for model_name in self.large_models:
|
||||
unmasker = pipeline(
|
||||
task="fill-mask",
|
||||
@@ -209,7 +184,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
"My name is <mask>",
|
||||
"The largest city in France is <mask>",
|
||||
]
|
||||
valid_targets = ["ĠPatrick", "ĠClara"]
|
||||
valid_targets = [" Patrick", " Clara"]
|
||||
for model_name in self.large_models:
|
||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user