[generation] bring back tests on vision models (#38603)
* bring back geenration tests on VLMs * remove head mask tests overwritten
This commit is contained in:
committed by
GitHub
parent
90c4b90a10
commit
dbfc79c17c
@@ -774,6 +774,7 @@ class Blip2TextModelTester:
|
||||
bos_token_id=self.pad_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
is_encoder_decoder=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -795,6 +796,9 @@ class Blip2ModelTester:
|
||||
self.text_model_tester = Blip2TextModelTester(parent, **text_kwargs)
|
||||
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
|
||||
self.seq_length = self.text_model_tester.seq_length # need seq_length for common tests
|
||||
self.encoder_seq_length = (
|
||||
self.text_model_tester.encoder_seq_length + num_query_tokens
|
||||
) # need enc seq_length for gen tests
|
||||
self.is_training = is_training
|
||||
self.num_query_tokens = num_query_tokens
|
||||
|
||||
@@ -859,11 +863,9 @@ class Blip2ModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
|
||||
additional_model_inputs = ["input_ids", "decoder_input_ids"]
|
||||
# Doesn't run generation tests. TODO: fix generation tests for Blip2ForConditionalGeneration
|
||||
all_generative_model_classes = ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": Blip2Model,
|
||||
|
||||
Reference in New Issue
Block a user