[Generate] correct encoder_outputs are passed without attention_mask (#14980)

* [Generate] correct encoder_outputs are passed without attention_mask

* Apply suggestions from code review

* up
This commit is contained in:
Patrick von Platen
2021-12-30 10:16:03 +01:00
committed by GitHub
parent a1392883ce
commit c043ce6cfd
3 changed files with 37 additions and 2 deletions

View File

@@ -1019,8 +1019,10 @@ class GenerationMixin:
model_kwargs["output_hidden_states"] = output_hidden_states model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["use_cache"] = use_cache model_kwargs["use_cache"] = use_cache
has_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
if model_kwargs.get("attention_mask", None) is None and has_attention_mask: requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, pad_token_id, eos_token_id inputs_tensor, pad_token_id, eos_token_id
) )

View File

@@ -1887,3 +1887,19 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist()) self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (2, 5)) self.assertEqual(output_sequences.shape, (2, 5))
def test_generate_encoder_outputs_attention_mask(self):
input_values = floats_tensor((2, 250)).to(torch_device)
attention_mask = torch.ones_like(input_values)
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
model = model.to(torch_device)
encoder = model.get_encoder()
encoder_outputs = encoder(input_values)
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs).cpu()
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask)
output_sequences_with_mask = output_sequences_with_mask.cpu()
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())

View File

@@ -215,3 +215,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
filename = ds[40]["file"] filename = ds[40]["file"]
output = speech_recognizer(filename) output = speech_recognizer(filename)
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."}) self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})
@slow
@require_torch
@require_torchaudio
def test_speech_to_text_leveraged(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="patrickvonplaten/wav2vec2-2-bart-base",
feature_extractor="patrickvonplaten/wav2vec2-2-bart-base",
tokenizer=AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2-2-bart-base"),
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})