From 83b26dd79d5640dda9f50fafced4da7d5b38d818 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:51:45 +0200 Subject: [PATCH] [`generate`] fix breaking change for patch (#29976) * fix bug and add tests * nit * otherway to get the cur len instead of attention mask * more places where this might have been broken * nit * oups * inputs_embeds vs input_embeds * test generated outptus * style * nit * fix * skip failing biogpt --- src/transformers/generation/utils.py | 8 ++++++++ tests/generation/test_utils.py | 13 +++++++++++++ tests/models/biogpt/test_modeling_biogpt.py | 4 ++++ 3 files changed, 25 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a958c8c86a..cb3ac0ff1d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3034,6 +3034,8 @@ class GenerationMixin: num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if num_beams * batch_size != batch_beam_size: @@ -3437,6 +3439,8 @@ class GenerationMixin: num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) # init attention / hidden states / scores tuples @@ -3795,6 +3799,8 @@ class GenerationMixin: device = input_ids.device batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if return_dict_in_generate and output_scores: @@ -4211,6 +4217,8 @@ class GenerationMixin: num_beams = constrained_beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) if num_beams * batch_size != batch_beam_size: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5c73e92a77..b346b745d8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -717,6 +717,19 @@ class GenerationTesterMixin: ) self.assertTrue(output_generate.shape[-1] == max_length) + if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters): + input_embeds = model.get_input_embeddings()(input_ids) + beam_kwargs.update({"inputs_embeds": input_embeds}) + output_generate2 = self._beam_sample_generate( + model=model, + input_ids=None, + attention_mask=attention_mask, + max_length=max_length, + beam_kwargs=beam_kwargs, + logits_warper_kwargs=logits_warper_kwargs, + ) + + torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2) def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py index 1055288e5c..58dd39e86a 100644 --- a/tests/models/biogpt/test_modeling_biogpt.py +++ b/tests/models/biogpt/test_modeling_biogpt.py @@ -414,6 +414,10 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + @unittest.skip("The `input_embeds` when fed don't produce the same results.") + def test_beam_sample_generate(self): + pass + @require_torch class BioGptModelIntegrationTest(unittest.TestCase):