Adding skip_special_tokens=True to FillMaskPipeline (#9783)
* We most likely don't want special tokens in this output. * Adding `skip_special_tokens=True` to FillMaskPipeline - It's backward incompatible. - It makes for sense for pipelines to remove references to special_tokens (all of the other pipelines do that). - Keeping special tokens makes it hard for users to actually remove them because all models have different tokens (<s>, <cls>, [CLS], ....) * Fixing `token_str` in the same vein, and actually fix the tests too !
This commit is contained in:
@@ -179,10 +179,10 @@ class FillMaskPipeline(Pipeline):
|
||||
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
|
||||
result.append(
|
||||
{
|
||||
"sequence": self.tokenizer.decode(tokens),
|
||||
"sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
|
||||
"score": v,
|
||||
"token": p,
|
||||
"token_str": self.tokenizer.convert_ids_to_tokens(p),
|
||||
"token_str": self.tokenizer.decode(p),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -22,31 +22,26 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
|
||||
EXPECTED_FILL_MASK_RESULT = [
|
||||
[
|
||||
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"},
|
||||
{"sequence": "<s>My name is Chris</s>", "score": 0.007475061342120171, "token": 1573, "token_str": "ĠChris"},
|
||||
{"sequence": "My name is John", "score": 0.00782308354973793, "token": 610, "token_str": " John"},
|
||||
{"sequence": "My name is Chris", "score": 0.007475061342120171, "token": 1573, "token_str": " Chris"},
|
||||
],
|
||||
[
|
||||
{"sequence": "<s>The largest city in France is Paris</s>", "score": 0.3185044229030609, "token": 2201},
|
||||
{"sequence": "<s>The largest city in France is Lyon</s>", "score": 0.21112334728240967, "token": 12790},
|
||||
{
|
||||
"sequence": "The largest city in France is Paris",
|
||||
"score": 0.2510891854763031,
|
||||
"token": 2201,
|
||||
"token_str": " Paris",
|
||||
},
|
||||
{
|
||||
"sequence": "The largest city in France is Lyon",
|
||||
"score": 0.21418564021587372,
|
||||
"token": 12790,
|
||||
"token_str": " Lyon",
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
EXPECTED_FILL_MASK_TARGET_RESULT = [
|
||||
[
|
||||
{
|
||||
"sequence": "<s>My name is Patrick</s>",
|
||||
"score": 0.004992353264242411,
|
||||
"token": 3499,
|
||||
"token_str": "ĠPatrick",
|
||||
},
|
||||
{
|
||||
"sequence": "<s>My name is Clara</s>",
|
||||
"score": 0.00019297805556561798,
|
||||
"token": 13606,
|
||||
"token_str": "ĠClara",
|
||||
},
|
||||
]
|
||||
]
|
||||
EXPECTED_FILL_MASK_TARGET_RESULT = [EXPECTED_FILL_MASK_RESULT[0]]
|
||||
|
||||
|
||||
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
@@ -138,7 +133,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
self.assertIsInstance(multi_result[0], (dict, list))
|
||||
|
||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
|
||||
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result]))
|
||||
for r, e in zip(result, expected):
|
||||
self.assertEqual(r["sequence"], e["sequence"])
|
||||
self.assertEqual(r["token_str"], e["token_str"])
|
||||
self.assertEqual(r["token"], e["token"])
|
||||
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
||||
|
||||
if isinstance(multi_result[0], list):
|
||||
multi_result = multi_result[0]
|
||||
@@ -162,7 +161,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
self.assertIsInstance(multi_result[0], (dict, list))
|
||||
|
||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
|
||||
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result]))
|
||||
for r, e in zip(result, expected):
|
||||
self.assertEqual(r["sequence"], e["sequence"])
|
||||
self.assertEqual(r["token_str"], e["token_str"])
|
||||
self.assertEqual(r["token"], e["token"])
|
||||
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
||||
|
||||
if isinstance(multi_result[0], list):
|
||||
multi_result = multi_result[0]
|
||||
@@ -197,7 +200,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
self.assertIsInstance(multi_result[0], (dict, list))
|
||||
|
||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
|
||||
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result]))
|
||||
for r, e in zip(result, expected):
|
||||
self.assertEqual(r["sequence"], e["sequence"])
|
||||
self.assertEqual(r["token_str"], e["token_str"])
|
||||
self.assertEqual(r["token"], e["token"])
|
||||
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
||||
|
||||
if isinstance(multi_result[0], list):
|
||||
multi_result = multi_result[0]
|
||||
@@ -221,7 +228,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
self.assertIsInstance(multi_result[0], (dict, list))
|
||||
|
||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
|
||||
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result]))
|
||||
for r, e in zip(result, expected):
|
||||
self.assertEqual(r["sequence"], e["sequence"])
|
||||
self.assertEqual(r["token_str"], e["token_str"])
|
||||
self.assertEqual(r["token"], e["token"])
|
||||
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
||||
|
||||
if isinstance(multi_result[0], list):
|
||||
multi_result = multi_result[0]
|
||||
|
||||
Reference in New Issue
Block a user