[Generate] correct encoder_outputs are passed without attention_mask (#14980)
* [Generate] correct encoder_outputs are passed without attention_mask * Apply suggestions from code review * up
This commit is contained in:
committed by
GitHub
parent
a1392883ce
commit
c043ce6cfd
@@ -1019,8 +1019,10 @@ class GenerationMixin:
|
|||||||
model_kwargs["output_hidden_states"] = output_hidden_states
|
model_kwargs["output_hidden_states"] = output_hidden_states
|
||||||
model_kwargs["use_cache"] = use_cache
|
model_kwargs["use_cache"] = use_cache
|
||||||
|
|
||||||
has_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||||
if model_kwargs.get("attention_mask", None) is None and has_attention_mask:
|
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||||
|
|
||||||
|
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
|
||||||
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||||
inputs_tensor, pad_token_id, eos_token_id
|
inputs_tensor, pad_token_id, eos_token_id
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1887,3 +1887,19 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
||||||
self.assertEqual(output_sequences.shape, (2, 5))
|
self.assertEqual(output_sequences.shape, (2, 5))
|
||||||
|
|
||||||
|
def test_generate_encoder_outputs_attention_mask(self):
|
||||||
|
input_values = floats_tensor((2, 250)).to(torch_device)
|
||||||
|
attention_mask = torch.ones_like(input_values)
|
||||||
|
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
|
||||||
|
model = model.to(torch_device)
|
||||||
|
|
||||||
|
encoder = model.get_encoder()
|
||||||
|
|
||||||
|
encoder_outputs = encoder(input_values)
|
||||||
|
|
||||||
|
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs).cpu()
|
||||||
|
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask)
|
||||||
|
output_sequences_with_mask = output_sequences_with_mask.cpu()
|
||||||
|
|
||||||
|
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
|
||||||
|
|||||||
@@ -215,3 +215,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
filename = ds[40]["file"]
|
filename = ds[40]["file"]
|
||||||
output = speech_recognizer(filename)
|
output = speech_recognizer(filename)
|
||||||
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})
|
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_torchaudio
|
||||||
|
def test_speech_to_text_leveraged(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="patrickvonplaten/wav2vec2-2-bart-base",
|
||||||
|
feature_extractor="patrickvonplaten/wav2vec2-2-bart-base",
|
||||||
|
tokenizer=AutoTokenizer.from_pretrained("patrickvonplaten/wav2vec2-2-bart-base"),
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||||
|
filename = ds[40]["file"]
|
||||||
|
output = speech_recognizer(filename)
|
||||||
|
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
|
||||||
|
|||||||
Reference in New Issue
Block a user