VLM: enable skipped tests (#35746)

* fix cached tests

* fix some tests

* fix pix2struct

* fix
This commit is contained in:
Raushan Turganbay
2025-02-12 12:55:46 +01:00
committed by GitHub
parent d6897b46bd
commit 8fc6ecba4f
10 changed files with 216 additions and 20 deletions

View File

@@ -579,6 +579,9 @@ BLIP_2_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r"""
@@ -2094,6 +2097,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
r"""
Returns:
@@ -2217,6 +2221,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@@ -2242,6 +2247,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=True, # toggle for easier access to loss/logits below
labels=labels,
use_cache=use_cache,
)
loss = outputs.loss
logits = outputs.logits

View File

@@ -441,6 +441,9 @@ INSTRUCTBLIP_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
@@ -1375,6 +1378,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1485,6 +1489,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@@ -1510,6 +1515,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]

View File

@@ -1265,6 +1265,9 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
@@ -1369,6 +1372,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1512,6 +1516,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@@ -1537,6 +1542,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]

View File

@@ -188,6 +188,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
r"""
```python
@@ -322,6 +323,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
@@ -347,6 +349,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
use_cache=use_cache,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]

View File

@@ -1694,6 +1694,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
past_key_values=None,
attention_mask=None,
use_cache=None,
cache_position=None,
**model_kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1704,17 +1705,21 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
attention_mask = input_ids.new_ones(input_shape)
position_ids = None
if cache_position is None:
past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
# cut input_ids if past_key_values is used
if past_key_values is not None:
position_ids = create_position_ids_from_input_ids(
input_ids,
padding_idx=self.config.pad_token_id,
past_key_values_length=0,
)[:, -1:]
)
if input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
position_ids = position_ids[:, -input_ids.shape[1] :]
input_ids = input_ids[:, -1:]
# the image info. is already encoded into the past keys/values
image_embeds = None
image_embeds_position_mask = None
elif image_embeds_position_mask is not None:

View File

@@ -516,7 +516,7 @@ class GenerationTesterMixin:
if self.has_attentions:
config._attn_implementation = "eager" # can't output attentions otherwise
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
@@ -651,7 +651,7 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
@@ -989,7 +989,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1018,7 +1018,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
if self.has_attentions:
@@ -1060,7 +1060,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1179,6 +1179,10 @@ class GenerationTesterMixin:
"prophetnet",
"seamlessm4t",
"clvp",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"instructblip",
"instructblipvideo",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1187,7 +1191,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1254,6 +1258,10 @@ class GenerationTesterMixin:
"seamlessm4t",
"clvp",
"fuyu",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"instructblip",
"instructblipvideo",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1262,7 +1270,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1368,6 +1376,10 @@ class GenerationTesterMixin:
"prophetnet",
"seamlessm4t",
"clvp",
"mllama", # special cache sizes
"blip2", # overridden `generate()`
"instructblip",
"instructblipvideo",
]
):
self.skipTest(reason="May fix in the future: need model-specific fixes")
@@ -1376,7 +1388,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.is_decoder = True
@@ -1570,7 +1582,7 @@ class GenerationTesterMixin:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
model = model_class(config).to(torch_device)
@@ -1605,7 +1617,14 @@ class GenerationTesterMixin:
# Encoder-Decoder checks
if config.is_encoder_decoder:
encoder_num_attention_heads = config.encoder_attention_heads
# encoder-decoder models usually don't have text config
# below is needed only for Pix2Struct which we cannot modify now due to BC
config = config.get_text_config()
encoder_num_attention_heads = (
config.encoder_attention_heads
if hasattr(config, "encoder_attention_heads")
else config.num_attention_heads
)
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
batch_size, seq_length = inputs["decoder_input_ids"].shape
for i in range(num_hidden_layers):
@@ -1804,14 +1823,14 @@ class GenerationTesterMixin:
def test_generate_continue_from_past_key_values(self):
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
# Let's make it always:
@@ -2251,7 +2270,7 @@ class GenerationTesterMixin:
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True

View File

@@ -82,14 +82,14 @@ class AriaVisionText2TextModelTester:
moe_intermediate_size=4,
moe_num_experts=4,
moe_topk=2,
num_attention_heads=20,
num_attention_heads=8,
num_experts_per_tok=3,
num_hidden_layers=2,
num_key_value_heads=20,
num_key_value_heads=8,
rope_theta=5000000,
vocab_size=99,
eos_token_id=2,
head_dim=2,
head_dim=4,
),
is_training=True,
vision_config=Idefics3VisionConfig(

View File

@@ -29,6 +29,7 @@ from transformers import (
is_torch_available,
is_vision_available,
)
from transformers.cache_utils import Cache
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
from transformers.testing_utils import (
cleanup,
@@ -378,6 +379,105 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
def test_offloaded_cache_implementation(self, cache_implementation):
pass
@unittest.skip(
reason="Mllama cache type doesn't allow correct check on output `past_key_values` due to `Cache.crop()`"
)
def test_contrastive_generate_dict_outputs_use_cache(self, assistant_type):
pass
@unittest.skip(reason="Mllama can't do low memory due to `Cache.crop()`")
def test_contrastive_generate_low_memory(self, assistant_type):
pass
@unittest.skip(reason="Mllama can't assisted decoding due to cache format and `Cache.crop()`")
def test_assisted_decoding_with_num_logits_to_keep(self):
pass
@pytest.mark.generate
# overriden because mllama has special cache for self and cross attentions
def test_past_key_values_format(self):
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
# standard KV cache format is important for a consistent API (and for advanced generation methods).
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config).to(torch_device)
if "use_cache" not in inputs:
inputs["use_cache"] = True
outputs = model(**inputs)
text_config = config.get_text_config()
num_hidden_layers = (
getattr(text_config, "decoder_layers", None)
or getattr(text_config, "num_decoder_layers", None)
or text_config.num_hidden_layers
)
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads
# some models have diffent num-head for query vs key/value so we need to assign correct value
# BUT only after `per_head_embed_dim` is set
num_attention_heads = (
text_config.num_key_value_heads
if getattr(text_config, "num_key_value_heads", None) is not None
else num_attention_heads
)
past_kv = outputs["past_key_values"]
self.assertEqual(len(past_kv), num_hidden_layers)
batch_size, seq_length = inputs["input_ids"].shape
for i in range(num_hidden_layers):
self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2
if i in self.model_tester.text_config["cross_attention_layers"]:
self.assertEqual(
past_kv[i][0].shape,
(batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim),
)
self.assertEqual(
past_kv[i][1].shape,
(batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim),
)
else:
self.assertEqual(
past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)
self.assertEqual(
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)
# overriden because mllama has special cache for self and cross attentions
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
self.assertIsInstance(decoder_past_key_values, Cache)
self.assertListEqual(
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
[True] * len(decoder_past_key_values),
)
for layer_idx, layer_past_key_values in enumerate(decoder_past_key_values):
if layer_idx in self.model_tester.text_config["cross_attention_layers"]:
expected_shape = (
batch_size,
config.num_key_value_heads
if hasattr(config, "num_key_value_heads")
else config.num_attention_heads,
self.model_tester.image_length,
config.hidden_size // config.num_attention_heads,
)
else:
# (batch, head, cache_length, head_features)
expected_shape = (
batch_size,
config.num_key_value_heads
if hasattr(config, "num_key_value_heads")
else config.num_attention_heads,
cache_length,
config.hidden_size // config.num_attention_heads,
)
# check shape key, value
self.assertListEqual([layer_past_key_values[0].shape], [expected_shape])
self.assertListEqual([layer_past_key_values[1].shape], [expected_shape])
def test_generate_text_only_with_cache(self):
"""
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature

View File

@@ -612,6 +612,18 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(
"Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
)
def test_greedy_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(
"Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
)
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa

View File

@@ -16,6 +16,8 @@
import unittest
from parameterized import parameterized
from transformers import (
PaliGemmaConfig,
PaliGemmaForConditionalGeneration,
@@ -348,3 +350,40 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
@unittest.skip("Low memory will be removed soon so no need to fix it")
def test_beam_search_low_memory(self):
pass
@parameterized.expand([("random",), ("same",)])
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache")
def test_generate_with_static_cache(self):
pass