VLMs: enable generation tests (#33533)
* add tests * fix whisper * update * nit * add qwen2-vl * more updates! * better this way * fix this one * fix more tests * fix final tests, hope so * fix led * Update tests/generation/test_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * pr comments * not pass pixels and extra for low-mem tests, very flaky because of visio tower --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e40bb4845e
commit
d7975a5874
@@ -289,7 +289,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.attention_type = "original_full"
|
||||
|
||||
input_ids = inputs_dict[self.input_name]
|
||||
input_ids = inputs_dict.pop(self.input_name)
|
||||
_ = inputs_dict.pop("attention_mask", None)
|
||||
_ = inputs_dict.pop("decoder_input_ids", None)
|
||||
_ = inputs_dict.pop("decoder_attention_mask", None)
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
@@ -300,7 +303,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
return config, input_ids, attention_mask
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BigBirdPegasusModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user