Generate: move prepare_inputs_for_generation in encoder-decoder llms (#34048)
This commit is contained in:
@@ -3841,6 +3841,38 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(model_inputs["input_ids"] is not None)
|
||||
self.assertTrue(model_inputs["inputs_embeds"] is None)
|
||||
|
||||
def test_prepare_inputs_for_generation_encoder_decoder_llm(self):
|
||||
"""
|
||||
Same as `test_prepare_inputs_for_generation_decoder_llm` but for encoder-decoder models. Main difference: we
|
||||
should look for `decoder_input_ids`, instead of `input_ids`.
|
||||
"""
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
model = model.to(torch_device)
|
||||
|
||||
# 1. Sanity check: the model's `prepare_inputs_for_generation` comes from `GenerationMixin`
|
||||
self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation))
|
||||
|
||||
# 2. If we pass input ids by themselves, we should get back the same input ids -- with the encoder-decoder key
|
||||
decoder_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
|
||||
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids)
|
||||
self.assertTrue(torch.all(model_inputs["decoder_input_ids"] == decoder_input_ids))
|
||||
|
||||
# 3. If we pass the attention mask too, we will get back the attention mask. Encoder-decoder models usually
|
||||
# don't use `position_ids`
|
||||
decoder_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device)
|
||||
model_inputs = model.prepare_inputs_for_generation(
|
||||
decoder_input_ids, decoder_attention_mask=decoder_attention_mask
|
||||
)
|
||||
self.assertTrue(torch.all(model_inputs["decoder_attention_mask"] == decoder_attention_mask))
|
||||
self.assertTrue("position_ids" not in model_inputs)
|
||||
|
||||
# 4. `use_cache` (and other kwargs, like the encoder outputs) are forwarded
|
||||
self.assertFalse("use_cache" in model_inputs) # From the previous input, there is no `use_cache`
|
||||
model_inputs = model.prepare_inputs_for_generation(decoder_input_ids, use_cache=True, encoder_outputs="foo")
|
||||
self.assertTrue(model_inputs["use_cache"] is True)
|
||||
self.assertTrue(model_inputs["encoder_outputs"] == "foo")
|
||||
# See the decoder-only test for more corner cases. The code is the same, so we don't repeat it here.
|
||||
|
||||
def test_generate_compile_fullgraph_tiny(self):
|
||||
"""
|
||||
Tests that we can call end-to-end generation with a tiny model (i.e. doesn't crash)
|
||||
|
||||
Reference in New Issue
Block a user