From 362fa37da239d4d4dea456850cabc79c8a2ffa16 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 22 Apr 2025 11:07:34 +0100 Subject: [PATCH] [test] update `test_past_key_values_format` (#37614) allow custom shapes --- tests/generation/test_utils.py | 157 +++++++++++------- .../deepseek_v3/test_modeling_deepseek_v3.py | 18 +- tests/models/falcon/test_modeling_falcon.py | 45 ----- tests/models/gemma/test_modeling_gemma.py | 4 - tests/models/glm/test_modeling_glm.py | 4 - .../models/got_ocr2/test_modeling_got_ocr2.py | 6 - .../models/imagegpt/test_modeling_imagegpt.py | 4 + tests/models/jetmoe/test_modeling_jetmoe.py | 4 - tests/models/mistral/test_modeling_mistral.py | 4 - .../mistral/test_modeling_tf_mistral.py | 4 - tests/models/mixtral/test_modeling_mixtral.py | 4 - tests/models/mllama/test_modeling_mllama.py | 2 +- tests/models/qwen2/test_modeling_qwen2.py | 4 - .../qwen2_moe/test_modeling_qwen2_moe.py | 4 - tests/models/qwen3/test_modeling_qwen3.py | 4 - .../qwen3_moe/test_modeling_qwen3_moe.py | 4 - .../test_modeling_recurrent_gemma.py | 4 - .../starcoder2/test_modeling_starcoder2.py | 4 - tests/models/zamba2/test_modeling_zamba2.py | 20 ++- 19 files changed, 134 insertions(+), 166 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c5794f52a0..caae8738e3 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1539,92 +1539,133 @@ class GenerationTesterMixin: torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5) @pytest.mark.generate - def test_past_key_values_format(self): - # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a - # standard KV cache format is important for a consistent API (and for advanced generation methods). + def test_past_key_values_format(self, custom_all_cache_shapes=None): + """ + Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the + expected cache shapes. + Having a standard KV cache format is important for a consistent API (and for advanced generation methods). + """ for model_class in self.all_generative_model_classes: config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - # If it doesn't support cache, pass the test + # 1. If it doesn't support cache, skip the test if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") model = model_class(config).to(torch_device) + model = model.eval() if "use_cache" not in inputs: inputs["use_cache"] = True outputs = model(**inputs) - # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format) if "past_key_values" not in outputs: self.skipTest(reason="This model doesn't return `past_key_values`") + # 2. retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided) + past_kv = outputs["past_key_values"] + is_legacy_cache = not isinstance(past_kv, Cache) + text_config = config.get_text_config() - num_hidden_layers = ( + num_decoder_layers = ( getattr(text_config, "decoder_layers", None) or getattr(text_config, "num_decoder_layers", None) or text_config.num_hidden_layers ) - num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads) - embed_dim = getattr(text_config, "d_model", text_config.hidden_size) - per_head_embed_dim = embed_dim // num_attention_heads - # some models have different num-head for query vs key/value so we need to assign correct value - # BUT only after `per_head_embed_dim` is set - num_attention_heads = ( - text_config.num_key_value_heads - if getattr(text_config, "num_key_value_heads", None) is not None - else num_attention_heads - ) - - past_kv = outputs["past_key_values"] - self.assertEqual(len(past_kv), num_hidden_layers) - - # Encoder-Decoder checks - if config.is_encoder_decoder: - # encoder-decoder models usually don't have text config - # below is needed only for Pix2Struct which we cannot modify now due to BC - config = config.get_text_config() - encoder_num_attention_heads = ( - config.encoder_attention_heads - if hasattr(config, "encoder_attention_heads") - else config.num_attention_heads + if custom_all_cache_shapes is None: + num_query_attention_heads = getattr( + text_config, "decoder_attention_heads", text_config.num_attention_heads ) - encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads - batch_size, seq_length = inputs["decoder_input_ids"].shape - for i in range(num_hidden_layers): - self.assertEqual(len(past_kv[i]), 4) # K V for the decoder + K V for the encoder = 4 - self.assertEqual( - past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) - ) - self.assertEqual( - past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) + embed_dim = getattr(text_config, "d_model", text_config.hidden_size) + per_head_embed_dim = embed_dim // num_query_attention_heads + num_key_value_heads = ( + text_config.num_key_value_heads + if getattr(text_config, "num_key_value_heads", None) is not None + else num_query_attention_heads + ) + if config.is_encoder_decoder: + encoder_num_attention_heads = ( + text_config.encoder_attention_heads + if hasattr(text_config, "encoder_attention_heads") + else text_config.num_attention_heads ) + encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads + batch_size, seq_length = inputs["decoder_input_ids"].shape # The sequence length for the encoder K V depends on the model. Since it is not manipulated in - # autoregressive generation, I'm keeping the test general and not checking the 3rd dim - self.assertEqual( - (past_kv[i][2].shape[0], past_kv[i][2].shape[1], past_kv[i][2].shape[3]), - (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim), - ) - self.assertEqual( - (past_kv[i][3].shape[0], past_kv[i][3].shape[1], past_kv[i][3].shape[3]), - (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim), + # autoregressive generation, we're keeping the test general and not checking the 3rd dim + default_cross_attention_shape = ( + batch_size, + encoder_num_attention_heads, + encoder_per_head_embed_dim, ) + default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) + all_cache_shapes = [ + [ + default_self_attention_shape, + default_self_attention_shape, + default_cross_attention_shape, + default_cross_attention_shape, + ] + for _ in range(num_decoder_layers) + ] + else: + batch_size, seq_length = inputs["input_ids"].shape + default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) + all_cache_shapes = [ + [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers) + ] - # Decoder-only checks else: - # TODO: this line is only needed because of imagegpt, where "pixel_values" = "input_ids". Fix the - # tests in imagegpt such that `prepare_config_and_inputs_for_common` returns the later (and the other - # tests use it) - key = "input_ids" if "input_ids" in inputs else "pixel_values" - batch_size, seq_length = inputs[key].shape - for i in range(num_hidden_layers): - self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2 - self.assertEqual( - past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) + all_cache_shapes = custom_all_cache_shapes + + # 3. Check cache shapes + # 3.1. Encoder-Decoder checks + if config.is_encoder_decoder: + num_cache_decoder_layers = ( + len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache) + ) + self.assertEqual(num_cache_decoder_layers, num_decoder_layers) + + for i in range(num_decoder_layers): + if is_legacy_cache: + self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple + + # Self attention + self_attention_layer_key_cache = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] ) - self.assertEqual( - past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) + self_attention_layer_value_cache = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] ) + self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + + # Cross attention (ignore 3rd dim, see default shape preparation) + cross_attention_layer_key_cache = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + ) + cross_attention_layer_value_cache = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] + ) + cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] + cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] + self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) + + # 3.2. Decoder-only checks + else: + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache) + self.assertEqual(num_cache_decoder_layers, num_decoder_layers) + + for i in range(num_decoder_layers): + if is_legacy_cache: + self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple + + # Self attention + self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] + self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] + self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) @pytest.mark.generate @parameterized.expand([("greedy", 1), ("beam search", 2)]) diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py index 5e5737df5a..1c2690b54e 100644 --- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py +++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py @@ -429,9 +429,23 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste with self.assertRaises(AssertionError): torch.testing.assert_close(yarn_sin_long, original_sin_long) - @unittest.skip(reason="Deepseek-V3 uses MLA on all models so the KV cache is a non standard format") def test_past_key_values_format(self): - pass + """ + Overwritting to pass the expected cache shapes (Deepseek-V3 uses MLA so the cache shapes are non-standard) + """ + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + batch_size, seq_length = inputs["input_ids"].shape + # difference: last dim + k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + v_embed_dim = config.v_head_dim + self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim) + self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim) + # build the full cache shapes + num_hidden_layers = config.num_hidden_layers + all_cache_shapes = [ + [self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers) + ] + super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes) @require_torch_sdpa @slow diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index cf52de0036..6a63177476 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -264,51 +264,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - def test_past_key_values_format(self): - # Falcon can have different numbers of KV-heads than the number of query heads, so we need - # to override this test to use the right head counts. - for model_class in self.all_generative_model_classes: - config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - - # If it doesn't support cache, pass the test - if not hasattr(config, "use_cache"): - self.skipTest(reason="Model does not support cache") - - model = model_class(config).to(torch_device) - if "use_cache" not in inputs: - inputs["use_cache"] = True - outputs = model(**inputs) - - # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format) - if "past_key_values" not in outputs: - self.skipTest(reason="Model does not return past_key_values") - - num_hidden_layers = ( - getattr(config, "decoder_layers", None) - or getattr(config, "num_decoder_layers", None) - or config.num_hidden_layers - ) - num_attention_heads = getattr(config, "num_kv_heads", config.num_attention_heads) - embed_dim = getattr(config, "d_model", config.hidden_size) - per_head_embed_dim = embed_dim // num_attention_heads - - past_kv = outputs["past_key_values"] - self.assertEqual(len(past_kv), num_hidden_layers) - - batch_size, seq_length = inputs["input_ids"].shape - for i in range(num_hidden_layers): - if config.new_decoder_architecture: - num_attention_heads = config.num_attention_heads - elif config.multi_query: - num_attention_heads = 1 - self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2 - self.assertEqual( - past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) - ) - self.assertEqual( - past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) - ) - @parameterized.expand([("linear",), ("dynamic",)]) # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon def test_model_rope_scaling_from_config(self, scaling_type): diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 4b7293817a..ce0aadd163 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -296,10 +296,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Gemma uses GQA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/glm/test_modeling_glm.py b/tests/models/glm/test_modeling_glm.py index f4dc0ab81b..9e8eda5cb2 100644 --- a/tests/models/glm/test_modeling_glm.py +++ b/tests/models/glm/test_modeling_glm.py @@ -264,10 +264,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Glm uses GQA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass - @is_flaky() def test_custom_4d_attention_mask(self): """Overwrite the common test to use atol=1e-3 instead of 1e-4. Can still rarely fail, thus flaky.""" diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py index f9ea723fd9..f604dbf036 100644 --- a/tests/models/got_ocr2/test_modeling_got_ocr2.py +++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py @@ -222,12 +222,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip( - reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format" - ) - def test_past_key_values_format(self): - pass - @unittest.skip("FlashAttention only support fp16 and bf16 data type") def test_flash_attn_2_fp32_ln(self): pass diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py index 98b1640d62..c20d00e733 100644 --- a/tests/models/imagegpt/test_modeling_imagegpt.py +++ b/tests/models/imagegpt/test_modeling_imagegpt.py @@ -319,6 +319,10 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM def test_left_padding_compatibility(self): pass + @unittest.skip(reason="Model inputs don't fit test pattern") # and it's not used enough to be worth fixing :) + def test_past_key_values_format(self): + pass + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py index d3b17c830f..0dfc7e2cef 100644 --- a/tests/models/jetmoe/test_modeling_jetmoe.py +++ b/tests/models/jetmoe/test_modeling_jetmoe.py @@ -251,10 +251,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - @unittest.skip(reason="JetMoe uses MoA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index ead979352a..7eee96f2ef 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -292,10 +292,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Mistral uses GQA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/mistral/test_modeling_tf_mistral.py b/tests/models/mistral/test_modeling_tf_mistral.py index beff383a98..aec7c6f23f 100644 --- a/tests/models/mistral/test_modeling_tf_mistral.py +++ b/tests/models/mistral/test_modeling_tf_mistral.py @@ -324,10 +324,6 @@ class TFMistralModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - @unittest.skip("Mistral uses GQA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass - @unittest.skip("Vocab resizing is not supported") def test_save_load_after_resize_token_embeddings(self): pass diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index b4b3b38edf..2d7c95529b 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -291,10 +291,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Mixtral uses GQA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index 711101f766..6308f6d4c0 100644 --- a/tests/models/mllama/test_modeling_mllama.py +++ b/tests/models/mllama/test_modeling_mllama.py @@ -409,7 +409,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester pass @pytest.mark.generate - # overridden because mllama has special cache for self and cross attentions + # overridden because mllama is not an encoder-decoder model, but has encoder-decoder-like cache def test_past_key_values_format(self): # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a # standard KV cache format is important for a consistent API (and for advanced generation methods). diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index d9fc25efc3..1339a09b64 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -303,10 +303,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Qwen2 uses GQA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index b89382867e..ecea6a3497 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -331,10 +331,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Qwen2Moe uses GQA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index 917713a217..0a5660ecd2 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -306,10 +306,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - # Ignore copy - def test_past_key_values_format(self): - super().test_past_key_values_format() - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/qwen3_moe/test_modeling_qwen3_moe.py b/tests/models/qwen3_moe/test_modeling_qwen3_moe.py index b192afae1e..c14f71407d 100644 --- a/tests/models/qwen3_moe/test_modeling_qwen3_moe.py +++ b/tests/models/qwen3_moe/test_modeling_qwen3_moe.py @@ -325,10 +325,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - # Ignore copy - def test_past_key_values_format(self): - super().test_past_key_values_format() - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py index b15bf26e06..4a41cecc0a 100644 --- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -223,10 +223,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - @unittest.skip(reason="RecurrentGemma does not return pkv") - def test_past_key_values_format(self): - pass - @unittest.skip(reason="RecurrentGemma only supports sdpa") def test_eager_matches_sdpa_generate(self): pass diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index 7c98425cfb..dbc3c0dc80 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -281,10 +281,6 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - @unittest.skip(reason="Starcoder2 uses GQA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 8c9b07ce89..78079293f3 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -322,14 +322,22 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass - @unittest.skip("Zamba2 has a hybrid cache") def test_past_key_values_format(self): - r""" - Zamba2's cache shape depends on whether a given layer is mamba or attention. - For mamba layers, the KV cache has shape is empty and has shape [batch_size, 0]. - The shape checks of this test assume instead that every layer has an attention cache, so we skip it. """ - pass + Overwritting to pass the expected cache shapes (Zamba2 has cache shape = [batch_size, 0] for mamba layers) + """ + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + batch_size, seq_length = inputs["input_ids"].shape + per_head_embed_dim = config.attention_head_dim # note: this one is not a common attribute name + self_attention_cache_shape = (batch_size, config.num_key_value_heads, seq_length, per_head_embed_dim) + # build the full cache shapes, including mamba layers + all_cache_shapes = [] + for i in range(config.num_hidden_layers): + if config.layers_block_type[i] == "mamba": + all_cache_shapes.append([torch.Size([batch_size, 0]), torch.Size([batch_size, 0])]) + else: + all_cache_shapes.append([self_attention_cache_shape, self_attention_cache_shape]) + super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes) @unittest.skip(reason="Zamba2 has hybrid cache.") def test_generate_continue_from_inputs_embeds(self):