Fix generate with inputs_embeds as input (#32493)
* I think inputs_embeds has ndim == 3 * fix sequence length catch * add generate test * [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama * skip whisper * fix bart test * more fixes
This commit is contained in:
@@ -502,6 +502,11 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Generate needs input ids")
|
||||
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||
# generate only works with input ids for bertforcausalLM
|
||||
pass
|
||||
|
||||
def test_model_as_decoder_with_default_input_mask(self):
|
||||
# This regression test was failing with PyTorch < 1.3
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user