Adding skip_special_tokens=True to FillMaskPipeline (#9783)
* We most likely don't want special tokens in this output.
* Adding `skip_special_tokens=True` to FillMaskPipeline
  - It's backward incompatible.
  - It makes more sense for pipelines to remove references to special_tokens (all of the other pipelines do that).
  - Keeping special tokens makes it hard for users to actually remove them because all models have different tokens (<s>, <cls>, [CLS], ...)
* Fixing `token_str` in the same vein, and actually fix the tests too!
This commit is contained in:
@@ -179,10 +179,10 @@ class FillMaskPipeline(Pipeline):
|
|||||||
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
|
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
|
||||||
result.append(
|
result.append(
|
||||||
{
|
{
|
||||||
"sequence": self.tokenizer.decode(tokens),
|
"sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
|
||||||
"score": v,
|
"score": v,
|
||||||
"token": p,
|
"token": p,
|
||||||
"token_str": self.tokenizer.convert_ids_to_tokens(p),
|
"token_str": self.tokenizer.decode(p),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -22,31 +22,26 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
|
|||||||
|
|
||||||
EXPECTED_FILL_MASK_RESULT = [
|
EXPECTED_FILL_MASK_RESULT = [
|
||||||
[
|
[
|
||||||
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"},
|
{"sequence": "My name is John", "score": 0.00782308354973793, "token": 610, "token_str": " John"},
|
||||||
{"sequence": "<s>My name is Chris</s>", "score": 0.007475061342120171, "token": 1573, "token_str": "ĠChris"},
|
{"sequence": "My name is Chris", "score": 0.007475061342120171, "token": 1573, "token_str": " Chris"},
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
{"sequence": "<s>The largest city in France is Paris</s>", "score": 0.3185044229030609, "token": 2201},
|
{
|
||||||
{"sequence": "<s>The largest city in France is Lyon</s>", "score": 0.21112334728240967, "token": 12790},
|
"sequence": "The largest city in France is Paris",
|
||||||
|
"score": 0.2510891854763031,
|
||||||
|
"token": 2201,
|
||||||
|
"token_str": " Paris",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"sequence": "The largest city in France is Lyon",
|
||||||
|
"score": 0.21418564021587372,
|
||||||
|
"token": 12790,
|
||||||
|
"token_str": " Lyon",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
|
|
||||||
EXPECTED_FILL_MASK_TARGET_RESULT = [
|
EXPECTED_FILL_MASK_TARGET_RESULT = [EXPECTED_FILL_MASK_RESULT[0]]
|
||||||
[
|
|
||||||
{
|
|
||||||
"sequence": "<s>My name is Patrick</s>",
|
|
||||||
"score": 0.004992353264242411,
|
|
||||||
"token": 3499,
|
|
||||||
"token_str": "ĠPatrick",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"sequence": "<s>My name is Clara</s>",
|
|
||||||
"score": 0.00019297805556561798,
|
|
||||||
"token": 13606,
|
|
||||||
"token_str": "ĠClara",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||||
@@ -138,7 +133,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
self.assertIsInstance(multi_result[0], (dict, list))
|
self.assertIsInstance(multi_result[0], (dict, list))
|
||||||
|
|
||||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
|
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
|
||||||
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result]))
|
for r, e in zip(result, expected):
|
||||||
|
self.assertEqual(r["sequence"], e["sequence"])
|
||||||
|
self.assertEqual(r["token_str"], e["token_str"])
|
||||||
|
self.assertEqual(r["token"], e["token"])
|
||||||
|
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
||||||
|
|
||||||
if isinstance(multi_result[0], list):
|
if isinstance(multi_result[0], list):
|
||||||
multi_result = multi_result[0]
|
multi_result = multi_result[0]
|
||||||
@@ -162,7 +161,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
self.assertIsInstance(multi_result[0], (dict, list))
|
self.assertIsInstance(multi_result[0], (dict, list))
|
||||||
|
|
||||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
|
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
|
||||||
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result]))
|
for r, e in zip(result, expected):
|
||||||
|
self.assertEqual(r["sequence"], e["sequence"])
|
||||||
|
self.assertEqual(r["token_str"], e["token_str"])
|
||||||
|
self.assertEqual(r["token"], e["token"])
|
||||||
|
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
||||||
|
|
||||||
if isinstance(multi_result[0], list):
|
if isinstance(multi_result[0], list):
|
||||||
multi_result = multi_result[0]
|
multi_result = multi_result[0]
|
||||||
@@ -197,7 +200,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
self.assertIsInstance(multi_result[0], (dict, list))
|
self.assertIsInstance(multi_result[0], (dict, list))
|
||||||
|
|
||||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
|
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
|
||||||
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result]))
|
for r, e in zip(result, expected):
|
||||||
|
self.assertEqual(r["sequence"], e["sequence"])
|
||||||
|
self.assertEqual(r["token_str"], e["token_str"])
|
||||||
|
self.assertEqual(r["token"], e["token"])
|
||||||
|
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
||||||
|
|
||||||
if isinstance(multi_result[0], list):
|
if isinstance(multi_result[0], list):
|
||||||
multi_result = multi_result[0]
|
multi_result = multi_result[0]
|
||||||
@@ -221,7 +228,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
self.assertIsInstance(multi_result[0], (dict, list))
|
self.assertIsInstance(multi_result[0], (dict, list))
|
||||||
|
|
||||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
|
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
|
||||||
self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result]))
|
for r, e in zip(result, expected):
|
||||||
|
self.assertEqual(r["sequence"], e["sequence"])
|
||||||
|
self.assertEqual(r["token_str"], e["token_str"])
|
||||||
|
self.assertEqual(r["token"], e["token"])
|
||||||
|
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
||||||
|
|
||||||
if isinstance(multi_result[0], list):
|
if isinstance(multi_result[0], list):
|
||||||
multi_result = multi_result[0]
|
multi_result = multi_result[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user