VLM: enable skipped tests (#35746)
* fix cached tests * fix some tests * fix pix2struct * fix
This commit is contained in:
committed by
GitHub
parent
d6897b46bd
commit
8fc6ecba4f
@@ -82,14 +82,14 @@ class AriaVisionText2TextModelTester:
|
||||
moe_intermediate_size=4,
|
||||
moe_num_experts=4,
|
||||
moe_topk=2,
|
||||
num_attention_heads=20,
|
||||
num_attention_heads=8,
|
||||
num_experts_per_tok=3,
|
||||
num_hidden_layers=2,
|
||||
num_key_value_heads=20,
|
||||
num_key_value_heads=8,
|
||||
rope_theta=5000000,
|
||||
vocab_size=99,
|
||||
eos_token_id=2,
|
||||
head_dim=2,
|
||||
head_dim=4,
|
||||
),
|
||||
is_training=True,
|
||||
vision_config=Idefics3VisionConfig(
|
||||
|
||||
@@ -29,6 +29,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
@@ -378,6 +379,105 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Mllama cache type doesn't allow correct check on output `past_key_values` due to `Cache.crop()`"
|
||||
)
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama can't do low memory due to `Cache.crop()`")
|
||||
def test_contrastive_generate_low_memory(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama can't assisted decoding due to cache format and `Cache.crop()`")
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
# overriden because mllama has special cache for self and cross attentions
|
||||
def test_past_key_values_format(self):
|
||||
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
|
||||
# standard KV cache format is important for a consistent API (and for advanced generation methods).
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
if "use_cache" not in inputs:
|
||||
inputs["use_cache"] = True
|
||||
outputs = model(**inputs)
|
||||
|
||||
text_config = config.get_text_config()
|
||||
num_hidden_layers = (
|
||||
getattr(text_config, "decoder_layers", None)
|
||||
or getattr(text_config, "num_decoder_layers", None)
|
||||
or text_config.num_hidden_layers
|
||||
)
|
||||
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
|
||||
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
|
||||
per_head_embed_dim = embed_dim // num_attention_heads
|
||||
|
||||
# some models have diffent num-head for query vs key/value so we need to assign correct value
|
||||
# BUT only after `per_head_embed_dim` is set
|
||||
num_attention_heads = (
|
||||
text_config.num_key_value_heads
|
||||
if getattr(text_config, "num_key_value_heads", None) is not None
|
||||
else num_attention_heads
|
||||
)
|
||||
|
||||
past_kv = outputs["past_key_values"]
|
||||
self.assertEqual(len(past_kv), num_hidden_layers)
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
for i in range(num_hidden_layers):
|
||||
self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2
|
||||
if i in self.model_tester.text_config["cross_attention_layers"]:
|
||||
self.assertEqual(
|
||||
past_kv[i][0].shape,
|
||||
(batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim),
|
||||
)
|
||||
self.assertEqual(
|
||||
past_kv[i][1].shape,
|
||||
(batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim),
|
||||
)
|
||||
else:
|
||||
self.assertEqual(
|
||||
past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
self.assertEqual(
|
||||
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
|
||||
# overriden because mllama has special cache for self and cross attentions
|
||||
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
|
||||
self.assertIsInstance(decoder_past_key_values, Cache)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
|
||||
[True] * len(decoder_past_key_values),
|
||||
)
|
||||
|
||||
for layer_idx, layer_past_key_values in enumerate(decoder_past_key_values):
|
||||
if layer_idx in self.model_tester.text_config["cross_attention_layers"]:
|
||||
expected_shape = (
|
||||
batch_size,
|
||||
config.num_key_value_heads
|
||||
if hasattr(config, "num_key_value_heads")
|
||||
else config.num_attention_heads,
|
||||
self.model_tester.image_length,
|
||||
config.hidden_size // config.num_attention_heads,
|
||||
)
|
||||
else:
|
||||
# (batch, head, cache_length, head_features)
|
||||
expected_shape = (
|
||||
batch_size,
|
||||
config.num_key_value_heads
|
||||
if hasattr(config, "num_key_value_heads")
|
||||
else config.num_attention_heads,
|
||||
cache_length,
|
||||
config.hidden_size // config.num_attention_heads,
|
||||
)
|
||||
# check shape key, value
|
||||
self.assertListEqual([layer_past_key_values[0].shape], [expected_shape])
|
||||
self.assertListEqual([layer_past_key_values[1].shape], [expected_shape])
|
||||
|
||||
def test_generate_text_only_with_cache(self):
|
||||
"""
|
||||
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
||||
|
||||
@@ -612,6 +612,18 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
|
||||
)
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
|
||||
)
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.")
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
PaliGemmaConfig,
|
||||
PaliGemmaForConditionalGeneration,
|
||||
@@ -348,3 +350,40 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
@unittest.skip("Low memory will be removed soon so no need to fix it")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user