From fa4bdb0a4060fd7a78bcad90dcc96a645ce11d31 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 13 Feb 2023 17:04:49 +0000
Subject: [PATCH] Generate: correct default model input creation for decoder-only models (#21580)

---
 src/transformers/generation/tf_utils.py     | 31 +++++++++++++-----
 src/transformers/generation/utils.py        | 32 +++++++++++++------
 tests/generation/test_utils.py              | 35 +++++++++++++++++++++
 tests/models/blip_2/test_modeling_blip_2.py | 28 +++++++++++++++++
 4 files changed, 109 insertions(+), 17 deletions(-)

diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index 63412879b8..1b18ed1c83 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -845,8 +845,7 @@ class TFGenerationMixin:
                 model_kwargs=model_kwargs,
             )
         else:
-            # if decoder-only then inputs_tensor has to be `input_ids`
-            input_ids = inputs_tensor
+            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
         # 7. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = shape_list(input_ids)[-1]
@@ -1214,20 +1213,34 @@ class TFGenerationMixin:
                         "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                         "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
                     )
+                # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
+                # the attention mask) can rely on the actual model input.
+                model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
+                    inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
+                )
             else:
                 if inputs is not None:
                     raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
-                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
 
         # 4. if `inputs` is still None, try to create `input_ids` from BOS token
-        if inputs is None:
-            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+        inputs = self._maybe_initialize_input_ids_for_generation(
+            inputs, bos_token_id, model_kwargs.get("encoder_outputs")
+        )
 
         return inputs, input_name, model_kwargs
 
-    def _prepare_input_ids_for_generation(
-        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
+    def _maybe_initialize_input_ids_for_generation(
+        self,
+        inputs: Optional[tf.Tensor] = None,
+        bos_token_id: Optional[int] = None,
+        encoder_outputs: Optional[ModelOutput] = None,
+        batch_size: Optional[int] = None,
     ) -> tf.Tensor:
+        """Initializes input ids for generation, if necessary."""
+        if inputs is not None:
+            return inputs
+
         if self.config.is_encoder_decoder and encoder_outputs is not None:
             # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
             shape = encoder_outputs.last_hidden_state.shape[:-1]
@@ -1235,7 +1248,9 @@ class TFGenerationMixin:
 
         if bos_token_id is None:
             raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
-        return tf.ones((1, 1), dtype=tf.int32) * bos_token_id
+
+        batch_size = batch_size if batch_size is not None else 1
+        return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
 
     @staticmethod
     def _extract_past_from_model_output(outputs: ModelOutput):
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 274d4a5554..014c66166f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -541,15 +541,20 @@ class GenerationMixin:
                         "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                         "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
                     )
+                # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
+                # the attention mask) can rely on the actual model input.
+                model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
+                    inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
+                )
             else:
                 if inputs is not None:
                     raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
-                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
 
         # 4. if `inputs` is still None, try to create `input_ids` from BOS token
-        if inputs is None:
-            inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
-
+        inputs = self._maybe_initialize_input_ids_for_generation(
+            inputs, bos_token_id, model_kwargs.get("encoder_outputs")
+        )
         return inputs, input_name, model_kwargs
 
     def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
@@ -558,9 +563,17 @@ class GenerationMixin:
         """
         return logits
 
-    def _prepare_input_ids_for_generation(
-        self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
+    def _maybe_initialize_input_ids_for_generation(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        bos_token_id: Optional[int] = None,
+        encoder_outputs: Optional[ModelOutput] = None,
+        batch_size: Optional[int] = None,
     ) -> torch.LongTensor:
+        """Initializes input ids for generation, if necessary."""
+        if inputs is not None:
+            return inputs
+
         if self.config.is_encoder_decoder and encoder_outputs is not None:
             # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
             shape = encoder_outputs.last_hidden_state.size()[:-1]
@@ -568,7 +581,9 @@ class GenerationMixin:
 
         if bos_token_id is None:
             raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
-        return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id
+
+        batch_size = batch_size if batch_size is not None else 1
+        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
 
     def _prepare_attention_mask_for_generation(
         self,
@@ -1258,8 +1273,7 @@ class GenerationMixin:
                 device=inputs_tensor.device,
             )
         else:
-            # if decoder-only then inputs_tensor has to be `input_ids`
-            input_ids = inputs_tensor
+            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 5dcc1472c4..88559e58b6 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2488,3 +2488,38 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         eos_token_id = [846, 198]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
+
+    def test_generate_from_inputs_embeds_decoder_only(self):
+        # Note: the model must support generation from input embeddings
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text, text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        # Traditional way of generating text
+        outputs_from_ids = model.generate(input_ids)
+        self.assertEqual(outputs_from_ids.shape, (2, 20))
+
+        # Same thing, but from input embeddings
+        inputs_embeds = model.transformer.wte(input_ids)
+        outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
+        self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
+
+        # But if we pass different inputs_embeds, we should get different outputs
+        torch.manual_seed(0)
+        random_embeds = torch.rand_like(inputs_embeds)
+        outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
+        with self.assertRaises(AssertionError):
+            self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
+
+        # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
+        outputs_from_embeds_wo_ids = model.generate(
+            inputs_embeds=inputs_embeds, max_new_tokens=20 - inputs_embeds.shape[1]
+        )
+        self.assertListEqual(
+            outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
+            outputs_from_embeds_wo_ids[:, 1:].tolist(),
+        )
diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py
index 40f64e971a..ef3bc18453 100644
--- a/tests/models/blip_2/test_modeling_blip_2.py
+++ b/tests/models/blip_2/test_modeling_blip_2.py
@@ -797,6 +797,20 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         )
         self.assertEqual(generated_text, "it's not a city, it's a beach")
 
+    def test_inference_opt_batched(self):
+        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(torch_device)
+
+        # prepare image
+        image = prepare_img()
+        inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
+
+        predictions = model.generate(**inputs)
+
+        # Test output
+        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+        self.assertEqual(predictions[1].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+
     def test_inference_t5(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
         model = Blip2ForConditionalGeneration.from_pretrained(
@@ -827,3 +841,17 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
             [0, 3, 7, 152, 67, 839, 1],
         )
         self.assertEqual(generated_text, "san diego")
+
+    def test_inference_t5_batched(self):
+        processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+        model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(torch_device)
+
+        # prepare image
+        image = prepare_img()
+        inputs = processor(images=[image, image], return_tensors="pt").to(torch_device)
+
+        predictions = model.generate(**inputs)
+
+        # Test output
+        self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
+        self.assertEqual(predictions[1].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])