Fix generate with inputs_embeds as input (#32493)
* I think inputs_embeds has ndim == 3 * fix sequence length catch * add generate test * [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama * skip whisper * fix bart test * more fixes
This commit is contained in:
@@ -4058,6 +4058,11 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
|
||||
# generate only works with input ids for whisper
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Generate needs input ids")
|
||||
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||
# generate only works with input ids for whisper
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Decoder can't keep attention grads")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user