From c043ce6cfd831b96dc3409d86e9b204b7600afee Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 30 Dec 2021 10:16:03 +0100 Subject: [PATCH] [Generate] correct encoder_outputs are passed without attention_mask (#14980) * [Generate] correct encoder_outputs are passed without attention_mask * Apply suggestions from code review * up --- src/transformers/generation_utils.py | 6 ++++-- tests/test_generation_utils.py | 16 ++++++++++++++++ ...st_pipelines_automatic_speech_recognition.py | 17 +++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 6990924e42..1906235ae1 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1019,8 +1019,10 @@ class GenerationMixin: model_kwargs["output_hidden_states"] = output_hidden_states model_kwargs["use_cache"] = use_cache - has_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) - if model_kwargs.get("attention_mask", None) is None and has_attention_mask: + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, pad_token_id, eos_token_id ) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 040350649e..1ffb02bd4b 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -1887,3 +1887,19 @@ class GenerationIntegrationTests(unittest.TestCase): self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) self.assertEqual(output_sequences.shape, (2, 5)) + + def test_generate_encoder_outputs_attention_mask(self): + input_values = floats_tensor((2, 250)).to(torch_device) + attention_mask = torch.ones_like(input_values) + model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder") + model = model.to(torch_device) + + encoder = model.get_encoder() + + encoder_outputs = encoder(input_values) + + output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs).cpu() + output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask) + output_sequences_with_mask = output_sequences_with_mask.cpu() + + self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist()) diff --git a/tests/test_pipelines_automatic_speech_recognition.py b/tests/test_pipelines_automatic_speech_recognition.py index 5d3e9cdc17..f3c105a74f 100644 --- a/tests/test_pipelines_automatic_speech_recognition.py +++ b/tests/test_pipelines_automatic_speech_recognition.py @@ -215,3 +215,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel filename = ds[40]["file"] output = speech_recognizer(filename) self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."}) + + @slow + @require_torch + @require_torchaudio + def test_speech_to_text_leveraged(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="patrickvonplaten/wav2vec2-2-bart-base", + feature_extractor="patrickvonplaten/wav2vec2-2-bart-base", + tokenizer=AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2-2-bart-base"), + framework="pt", + ) + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") + filename = ds[40]["file"] + output = speech_recognizer(filename) + self.assertEqual(output, {"text": "a man said to the universe sir i exist"})