Generate: move prepare_inputs_for_generation in encoder-decoder llms (#34048)

This commit is contained in:
Joao Gante
2024-10-11 16:11:18 +01:00
committed by GitHub
parent fd70464fa7
commit 37ac078535
25 changed files with 49 additions and 725 deletions

View File

@@ -3841,6 +3841,38 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(model_inputs["input_ids"] is not None)
self.assertTrue(model_inputs["inputs_embeds"] is None)
def test_prepare_inputs_for_generation_encoder_decoder_llm(self):
"""
Same as `test_prepare_inputs_for_generation_decoder_llm` but for encoder-decoder models. Main difference: we
should look for `decoder_input_ids`, instead of `input_ids`.
"""
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
model = model.to(torch_device)
# 1. Sanity check: the model's `prepare_inputs_for_generation` comes from `GenerationMixin`
self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation))
# 2. If we pass input ids by themselves, we should get back the same input ids -- with the encoder-decoder key
decoder_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids)
self.assertTrue(torch.all(model_inputs["decoder_input_ids"] == decoder_input_ids))
# 3. If we pass the attention mask too, we will get back the attention mask. Encoder-decoder models usually
# don't use `position_ids`
decoder_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device)
model_inputs = model.prepare_inputs_for_generation(
decoder_input_ids, decoder_attention_mask=decoder_attention_mask
)
self.assertTrue(torch.all(model_inputs["decoder_attention_mask"] == decoder_attention_mask))
self.assertTrue("position_ids" not in model_inputs)
# 4. `use_cache` (and other kwargs, like the encoder outputs) are forwarded
self.assertFalse("use_cache" in model_inputs) # From the previous input, there is no `use_cache`
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids, use_cache=True, encoder_outputs="foo")
self.assertTrue(model_inputs["use_cache"] is True)
self.assertTrue(model_inputs["encoder_outputs"] == "foo")
# See the decoder-only test for more corner cases. The code is the same, so we don't repeat it here.
def test_generate_compile_fullgraph_tiny(self):
"""
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)