From ecb4662d17d6f1a32e37607902b00c6983772264 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 18 Mar 2022 10:02:12 +0100 Subject: [PATCH] Attention mask is important in the case of batching... (#16222) * Attention mask is important in the case of batching... * Improve the fix. * Making the sentences different enough that they exhibit different predictions. --- src/transformers/pipelines/base.py | 2 +- .../pipelines/token_classification.py | 1 - .../test_pipelines_token_classification.py | 17 +++++++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 62e3abf37e..2f966337e6 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -149,7 +149,7 @@ def pad_collate_fn(tokenizer, feature_extractor): _padding_value = t_padding_value elif key in {"input_values", "pixel_values", "input_features"}: _padding_value = f_padding_value - elif key in {"p_mask"}: + elif key in {"p_mask", "special_tokens_mask"}: _padding_value = 1 elif key in {"attention_mask", "token_type_ids"}: _padding_value = 0 diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index d14616a9aa..56fe453dfb 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -192,7 +192,6 @@ class TokenClassificationPipeline(Pipeline): truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False model_inputs = self.tokenizer( sentence, - return_attention_mask=False, return_tensors=self.framework, truncation=truncation, return_special_tokens_mask=True, diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py index 94ac7a19ce..26cfa0d3be 100644 --- a/tests/pipelines/test_pipelines_token_classification.py +++ 
b/tests/pipelines/test_pipelines_token_classification.py @@ -649,6 +649,23 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ], ) + # Batch size does not affect outputs (attention_mask are required) + sentences = ["This is a test !", "Another test this is with longer sentence"] + outputs = token_classifier(sentences) + outputs_batched = token_classifier(sentences, batch_size=2) + # Batching does not make a difference in predictions + self.assertEqual(nested_simplify(outputs_batched), nested_simplify(outputs)) + self.assertEqual( + nested_simplify(outputs_batched), + [ + [ + {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4}, + {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7}, + ], + [], + ], + ) + @require_torch def test_pt_ignore_subwords_slow_tokenizer_raises(self): model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"