From 781e4b1384e86f2d8012072a6557a663d2308704 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 26 Jan 2021 10:06:28 +0100 Subject: [PATCH] Adding `skip_special_tokens=True` to FillMaskPipeline (#9783) * We most likely don't want special tokens in this output. * Adding `skip_special_tokens=True` to FillMaskPipeline - It's backward incompatible. - It makes for sense for pipelines to remove references to special_tokens (all of the other pipelines do that). - Keeping special tokens makes it hard for users to actually remove them because all models have different tokens (, , [CLS], ....) * Fixing `token_str` in the same vein, and actually fix the tests too ! --- src/transformers/pipelines/fill_mask.py | 4 +- tests/test_pipelines_fill_mask.py | 59 +++++++++++++++---------- 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py index 8da7f059db..251c7f0973 100644 --- a/src/transformers/pipelines/fill_mask.py +++ b/src/transformers/pipelines/fill_mask.py @@ -179,10 +179,10 @@ class FillMaskPipeline(Pipeline): tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)] result.append( { - "sequence": self.tokenizer.decode(tokens), + "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True), "score": v, "token": p, - "token_str": self.tokenizer.convert_ids_to_tokens(p), + "token_str": self.tokenizer.decode(p), } ) diff --git a/tests/test_pipelines_fill_mask.py b/tests/test_pipelines_fill_mask.py index f087ed2135..f86fc9c3d1 100644 --- a/tests/test_pipelines_fill_mask.py +++ b/tests/test_pipelines_fill_mask.py @@ -22,31 +22,26 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin EXPECTED_FILL_MASK_RESULT = [ [ - {"sequence": "My name is John", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"}, - {"sequence": "My name is Chris", "score": 0.007475061342120171, "token": 1573, "token_str": "ĠChris"}, + {"sequence": "My name is John", "score": 0.00782308354973793, "token": 610, "token_str": " John"}, + {"sequence": "My name is Chris", "score": 0.007475061342120171, "token": 1573, "token_str": " Chris"}, ], [ - {"sequence": "The largest city in France is Paris", "score": 0.3185044229030609, "token": 2201}, - {"sequence": "The largest city in France is Lyon", "score": 0.21112334728240967, "token": 12790}, + { + "sequence": "The largest city in France is Paris", + "score": 0.2510891854763031, + "token": 2201, + "token_str": " Paris", + }, + { + "sequence": "The largest city in France is Lyon", + "score": 0.21418564021587372, + "token": 12790, + "token_str": " Lyon", + }, ], ] -EXPECTED_FILL_MASK_TARGET_RESULT = [ - [ - { - "sequence": "My name is Patrick", - "score": 0.004992353264242411, - "token": 3499, - "token_str": "ĠPatrick", - }, - { - "sequence": "My name is Clara", - "score": 0.00019297805556561798, - "token": 13606, - "token_str": "ĠClara", - }, - ] -] +EXPECTED_FILL_MASK_TARGET_RESULT = [EXPECTED_FILL_MASK_RESULT[0]] class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): @@ -138,7 +133,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): self.assertIsInstance(multi_result[0], (dict, list)) for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT): - self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result])) + for r, e in zip(result, expected): + self.assertEqual(r["sequence"], e["sequence"]) + self.assertEqual(r["token_str"], e["token_str"]) + self.assertEqual(r["token"], e["token"]) + self.assertAlmostEqual(r["score"], e["score"], places=3) if isinstance(multi_result[0], list): multi_result = multi_result[0] @@ -162,7 +161,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): self.assertIsInstance(multi_result[0], (dict, list)) for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT): - self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result])) + for r, e in zip(result, expected): + self.assertEqual(r["sequence"], e["sequence"]) + self.assertEqual(r["token_str"], e["token_str"]) + self.assertEqual(r["token"], e["token"]) + self.assertAlmostEqual(r["score"], e["score"], places=3) if isinstance(multi_result[0], list): multi_result = multi_result[0] @@ -197,7 +200,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): self.assertIsInstance(multi_result[0], (dict, list)) for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT): - self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result])) + for r, e in zip(result, expected): + self.assertEqual(r["sequence"], e["sequence"]) + self.assertEqual(r["token_str"], e["token_str"]) + self.assertEqual(r["token"], e["token"]) + self.assertAlmostEqual(r["score"], e["score"], places=3) if isinstance(multi_result[0], list): multi_result = multi_result[0] @@ -221,7 +228,11 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): self.assertIsInstance(multi_result[0], (dict, list)) for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT): - self.assertEqual(set([o["sequence"] for o in result]), set([o["sequence"] for o in result])) + for r, e in zip(result, expected): + self.assertEqual(r["sequence"], e["sequence"]) + self.assertEqual(r["token_str"], e["token_str"]) + self.assertEqual(r["token"], e["token"]) + self.assertAlmostEqual(r["score"], e["score"], places=3) if isinstance(multi_result[0], list): multi_result = multi_result[0]