VLM: enable skipped tests (#35746)

* fix cached tests

* fix some tests

* fix pix2struct

* fix
This commit is contained in:
Raushan Turganbay
2025-02-12 12:55:46 +01:00
committed by GitHub
parent d6897b46bd
commit 8fc6ecba4f
10 changed files with 216 additions and 20 deletions

View File

@@ -516,7 +516,7 @@ class GenerationTesterMixin:
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
@@ -651,7 +651,7 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
@@ -989,7 +989,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1018,7 +1018,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
if self.has_attentions:
@@ -1060,7 +1060,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1179,6 +1179,10 @@ class GenerationTesterMixin:
"prophetnet",
"seamlessm4t",
"clvp",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"instructblip",
"instructblipvideo",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1187,7 +1191,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1254,6 +1258,10 @@ class GenerationTesterMixin:
"seamlessm4t",
"clvp",
"fuyu",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"instructblip",
"instructblipvideo",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1262,7 +1270,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1368,6 +1376,10 @@ class GenerationTesterMixin:
"prophetnet",
"seamlessm4t",
"clvp",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"instructblip",
"instructblipvideo",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1376,7 +1388,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1570,7 +1582,7 @@ class GenerationTesterMixin:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# If the model doesn't support caching, skip the test
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
model = model_class(config).to(torch_device)
@@ -1605,7 +1617,14 @@ class GenerationTesterMixin:
# Encoder-Decoder checks
if config.is_encoder_decoder:
encoder_num_attention_heads = config.encoder_attention_heads
# encoder-decoder models usually don't have a text config
# the `get_text_config()` lookup below is needed only for Pix2Struct, which we cannot modify now due to backward compatibility (BC)
config = config.get_text_config()
encoder_num_attention_heads = (
config.encoder_attention_heads
if hasattr(config, "encoder_attention_heads")
else config.num_attention_heads
)
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
batch_size, seq_length = inputs["decoder_input_ids"].shape
for i in range(num_hidden_layers):
@@ -1804,14 +1823,14 @@ class GenerationTesterMixin:
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
# Let's make it always:
@@ -2251,7 +2270,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True