Fix generate with inputs_embeds as input (#32493)

* I think inputs_embeds has ndim == 3

* fix sequence length catch

* add generate test

* [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama

* skip whisper

* fix bart test

* more fixes
This commit is contained in:
Pablo Montalvo
2024-08-08 18:44:53 +02:00
committed by GitHub
parent b01f9c484c
commit 044281605f
23 changed files with 213 additions and 144 deletions

View File

@@ -4058,6 +4058,11 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
# generate only works with input ids for whisper
pass
@unittest.skip(reason="Generate needs input ids")
def test_inputs_embeds_matches_input_ids_with_generate(self):
# generate only works with input ids for whisper
pass
@unittest.skip(reason="Decoder can't keep attention grads")
def test_retain_grad_hidden_states_attentions(self):
return