Fix generate with inputs_embeds as input (#32493)

* I think inputs_embeds has ndim == 3

* fix sequence length catch

* add generate test

* [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama

* skip whisper

* fix bart test

* more fixes
This commit is contained in:
Pablo Montalvo
2024-08-08 18:44:53 +02:00
committed by GitHub
parent b01f9c484c
commit 044281605f
23 changed files with 213 additions and 144 deletions

View File

@@ -502,6 +502,11 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@unittest.skip(reason="Generate needs input ids")
def test_inputs_embeds_matches_input_ids_with_generate(self):
# generate only works with input ids for bertforcausalLM
pass
def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3
(