[Generate] correct encoder_outputs are passed without attention_mask (#14980)
* [Generate] correct encoder_outputs are passed without attention_mask * Apply suggestions from code review * up
This commit is contained in:
committed by
GitHub
parent
a1392883ce
commit
c043ce6cfd
@@ -1019,8 +1019,10 @@ class GenerationMixin:
|
||||
model_kwargs["output_hidden_states"] = output_hidden_states
|
||||
model_kwargs["use_cache"] = use_cache
|
||||
|
||||
has_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||
if model_kwargs.get("attention_mask", None) is None and has_attention_mask:
|
||||
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||
|
||||
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
|
||||
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||
inputs_tensor, pad_token_id, eos_token_id
|
||||
)
|
||||
|
||||
@@ -1887,3 +1887,19 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
||||
self.assertEqual(output_sequences.shape, (2, 5))
|
||||
|
||||
def test_generate_encoder_outputs_attention_mask(self):
|
||||
input_values = floats_tensor((2, 250)).to(torch_device)
|
||||
attention_mask = torch.ones_like(input_values)
|
||||
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
|
||||
model = model.to(torch_device)
|
||||
|
||||
encoder = model.get_encoder()
|
||||
|
||||
encoder_outputs = encoder(input_values)
|
||||
|
||||
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs).cpu()
|
||||
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask)
|
||||
output_sequences_with_mask = output_sequences_with_mask.cpu()
|
||||
|
||||
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
|
||||
|
||||
@@ -215,3 +215,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
def test_speech_to_text_leveraged(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="patrickvonplaten/wav2vec2-2-bart-base",
|
||||
feature_extractor="patrickvonplaten/wav2vec2-2-bart-base",
|
||||
tokenizer=AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2-2-bart-base"),
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
||||
|
||||
Reference in New Issue
Block a user