From 3c2383b1c6eb860c0511d081e670d1782cd66b8d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 23 Aug 2023 19:17:01 +0100 Subject: [PATCH] Generate: general test for decoder-only generation from `inputs_embeds` (#25687) Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- tests/generation/test_utils.py | 86 ++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 36 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ea9ab3c753..f983f527a8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1750,6 +1750,56 @@ class GenerationTesterMixin: past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) ) + def test_generate_from_inputs_embeds_decoder_only(self): + # When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids` + # if fails, you should probably update the `prepare_inputs_for_generation` function + for model_class in self.all_generative_model_classes: + config, input_ids, _, _ = self._get_input_ids_and_config() + + # Ignore: + # a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids, + # which would cause a mismatch), + config.pad_token_id = config.eos_token_id = -1 + # b) embedding scaling, the scaling factor applied after embeding from input_ids (requires knowledge of the + # variable that holds the scaling factor, which is model-dependent) + if hasattr(config, "scale_embedding"): + config.scale_embedding = False + + # This test is for decoder-only models (encoder-decoder models have native input embeddings support in the + # decoder) + if config.is_encoder_decoder: + continue + + # Skip models without explicit support + model = model_class(config).to(torch_device).eval() + if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): + continue + + # Traditional way of generating text + outputs_from_ids = model.generate(input_ids) + self.assertEqual(outputs_from_ids.shape, (2, 20)) + + # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) + inputs_embeds = model.get_input_embeddings()(input_ids) + outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds) + self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist()) + + # But if we pass different inputs_embeds, we should get different outputs + torch.manual_seed(0) + random_embeds = torch.rand_like(inputs_embeds) + outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist()) + + # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same + outputs_from_embeds_wo_ids = model.generate( + inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1] + ) + self.assertListEqual( + outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(), + outputs_from_embeds_wo_ids[:, 1:].tolist(), + ) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences @@ -2773,42 +2823,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs) self.assertTrue(expectation == len(generated_tokens[0])) - def test_generate_from_inputs_embeds_decoder_only(self): - # PT-only test: TF doesn't have a model with support to generate from input embeds (yet ;)) - # Note: the model must support generation from input embeddings - model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) - tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") - model.config.pad_token_id = tokenizer.eos_token_id - - text = "Hello world" - tokenized_inputs = tokenizer([text, text], return_tensors="pt") - input_ids = tokenized_inputs.input_ids.to(torch_device) - - # Traditional way of generating text - outputs_from_ids = model.generate(input_ids) - self.assertEqual(outputs_from_ids.shape, (2, 20)) - - # Same thing, but from input embeddings - inputs_embeds = model.transformer.wte(input_ids) - outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds) - self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist()) - - # But if we pass different inputs_embeds, we should get different outputs - torch.manual_seed(0) - random_embeds = torch.rand_like(inputs_embeds) - outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds) - with self.assertRaises(AssertionError): - self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist()) - - # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same - outputs_from_embeds_wo_ids = model.generate( - inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1] - ) - self.assertListEqual( - outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(), - outputs_from_embeds_wo_ids[:, 1:].tolist(), - ) - def test_model_kwarg_encoder_signature_filtering(self): # Has TF equivalent: ample use of framework-specific code bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")