[Generate] correct encoder_outputs are passed without attention_mask (#14980)
* [Generate] correct encoder_outputs are passed without attention_mask * Apply suggestions from code review * up
This commit is contained in:
committed by
GitHub
parent
a1392883ce
commit
c043ce6cfd
@@ -1887,3 +1887,19 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
||||
self.assertEqual(output_sequences.shape, (2, 5))
|
||||
|
||||
def test_generate_encoder_outputs_attention_mask(self):
|
||||
input_values = floats_tensor((2, 250)).to(torch_device)
|
||||
attention_mask = torch.ones_like(input_values)
|
||||
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
|
||||
model = model.to(torch_device)
|
||||
|
||||
encoder = model.get_encoder()
|
||||
|
||||
encoder_outputs = encoder(input_values)
|
||||
|
||||
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs).cpu()
|
||||
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask)
|
||||
output_sequences_with_mask = output_sequences_with_mask.cpu()
|
||||
|
||||
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
|
||||
|
||||
Reference in New Issue
Block a user