[generate] fix breaking change for patch (#29976)

* fix bug and add tests

* nit

* otherway to get the cur len instead of attention mask

* more places where this might have been broken

* nit

* oups

* inputs_embeds vs input_embeds

* test generated outptus

* style

* nit

* fix

* skip failing biogpt
This commit is contained in:
Arthur
2024-04-02 09:51:45 +02:00
committed by GitHub
parent 096f304695
commit 83b26dd79d
3 changed files with 25 additions and 0 deletions

View File

@@ -717,6 +717,19 @@ class GenerationTesterMixin:
)
self.assertTrue(output_generate.shape[-1] == max_length)
if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
input_embeds = model.get_input_embeddings()(input_ids)
beam_kwargs.update({"inputs_embeds": input_embeds})
output_generate2 = self._beam_sample_generate(
model=model,
input_ids=None,
attention_mask=attention_mask,
max_length=max_length,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
)
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
def test_beam_sample_generate_dict_output(self):
for model_class in self.all_generative_model_classes:

View File

@@ -414,6 +414,10 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("The `input_embeds` when fed don't produce the same results.")
def test_beam_sample_generate(self):
pass
@require_torch
class BioGptModelIntegrationTest(unittest.TestCase):