From 65ee1a43e5aee0ebf34b4a32b5cf01ce58311445 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 13 Sep 2021 12:48:54 +0200 Subject: [PATCH] fixing BC in `fill-mask` (wasn't tested in theses test suites (#13540) apparently). --- src/transformers/pipelines/fill_mask.py | 5 ++++- tests/test_pipelines_fill_mask.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py index bca61d144d..5392db979b 100644 --- a/src/transformers/pipelines/fill_mask.py +++ b/src/transformers/pipelines/fill_mask.py @@ -219,4 +219,7 @@ class FillMaskPipeline(Pipeline): - **token** (:obj:`int`) -- The predicted token id (to replace the masked one). - **token** (:obj:`str`) -- The predicted token (to replace the masked one). """ - return super().__call__(inputs, **kwargs) + outputs = super().__call__(inputs, **kwargs) + if isinstance(inputs, list) and len(inputs) == 1: + return outputs[0] + return outputs diff --git a/tests/test_pipelines_fill_mask.py b/tests/test_pipelines_fill_mask.py index 8f65417f47..fb48fe52cd 100644 --- a/tests/test_pipelines_fill_mask.py +++ b/tests/test_pipelines_fill_mask.py @@ -186,6 +186,18 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): ], ) + outputs = fill_masker([f"This is a {tokenizer.mask_token}"]) + self.assertEqual( + outputs, + [ + {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)}, + {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)}, + {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)}, + {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)}, + {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)}, + ], + ) + outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token} great test."]) self.assertEqual( outputs,