[test] update test_past_key_values_format (#37614)

allow custom shapes
This commit is contained in:
Joao Gante
2025-04-22 11:07:34 +01:00
committed by GitHub
parent 1cd110c6cb
commit 362fa37da2
19 changed files with 134 additions and 166 deletions

View File

@@ -429,9 +429,23 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
with self.assertRaises(AssertionError):
torch.testing.assert_close(yarn_sin_long, original_sin_long)
@unittest.skip(reason="Deepseek-V3 uses MLA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
"""
Overwritting to pass the expected cache shapes (Deepseek-V3 uses MLA so the cache shapes are non-standard)
"""
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, seq_length = inputs["input_ids"].shape
# difference: last dim
k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
v_embed_dim = config.v_head_dim
self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
# build the full cache shapes
num_hidden_layers = config.num_hidden_layers
all_cache_shapes = [
[self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
]
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
@require_torch_sdpa
@slow

View File

@@ -264,51 +264,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def test_past_key_values_format(self):
# Falcon can have different numbers of KV-heads than the number of query heads, so we need
# to override this test to use the right head counts.
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# If it doesn't support cache, pass the test
if not hasattr(config, "use_cache"):
self.skipTest(reason="Model does not support cache")
model = model_class(config).to(torch_device)
if "use_cache" not in inputs:
inputs["use_cache"] = True
outputs = model(**inputs)
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
if "past_key_values" not in outputs:
self.skipTest(reason="Model does not return past_key_values")
num_hidden_layers = (
getattr(config, "decoder_layers", None)
or getattr(config, "num_decoder_layers", None)
or config.num_hidden_layers
)
num_attention_heads = getattr(config, "num_kv_heads", config.num_attention_heads)
embed_dim = getattr(config, "d_model", config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads
past_kv = outputs["past_key_values"]
self.assertEqual(len(past_kv), num_hidden_layers)
batch_size, seq_length = inputs["input_ids"].shape
for i in range(num_hidden_layers):
if config.new_decoder_architecture:
num_attention_heads = config.num_attention_heads
elif config.multi_query:
num_attention_heads = 1
self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2
self.assertEqual(
past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)
self.assertEqual(
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
)
@parameterized.expand([("linear",), ("dynamic",)])
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon
def test_model_rope_scaling_from_config(self, scaling_type):

View File

@@ -296,10 +296,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip(reason="Gemma uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@@ -264,10 +264,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip(reason="Glm uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@is_flaky()
def test_custom_4d_attention_mask(self):
"""Overwrite the common test to use atol=1e-3 instead of 1e-4. Can still rarely fail, thus flaky."""

View File

@@ -222,12 +222,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(
reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
)
def test_past_key_values_format(self):
pass
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass

View File

@@ -319,6 +319,10 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
def test_left_padding_compatibility(self):
pass
@unittest.skip(reason="Model inputs don't fit test pattern") # and it's not used enough to be worth fixing :)
def test_past_key_values_format(self):
pass
# We will verify our results on an image of cute cats
def prepare_img():

View File

@@ -251,10 +251,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip(reason="JetMoe uses MoA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@@ -292,10 +292,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip(reason="Mistral uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@@ -324,10 +324,6 @@ class TFMistralModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("Mistral uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@unittest.skip("Vocab resizing is not supported")
def test_save_load_after_resize_token_embeddings(self):
pass

View File

@@ -291,10 +291,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip(reason="Mixtral uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@@ -409,7 +409,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
pass
@pytest.mark.generate
# overridden because mllama has special cache for self and cross attentions
# overridden because mllama is not an encoder-decoder model, but has encoder-decoder-like cache
def test_past_key_values_format(self):
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
# standard KV cache format is important for a consistent API (and for advanced generation methods).

View File

@@ -303,10 +303,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip(reason="Qwen2 uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@@ -331,10 +331,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip(reason="Qwen2Moe uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@@ -306,10 +306,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
# Ignore copy
def test_past_key_values_format(self):
super().test_past_key_values_format()
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@@ -325,10 +325,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
# Ignore copy
def test_past_key_values_format(self):
super().test_past_key_values_format()
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@@ -223,10 +223,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="RecurrentGemma does not return pkv")
def test_past_key_values_format(self):
pass
@unittest.skip(reason="RecurrentGemma only supports sdpa")
def test_eager_matches_sdpa_generate(self):
pass

View File

@@ -281,10 +281,6 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip(reason="Starcoder2 uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test

View File

@@ -322,14 +322,22 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Zamba2 has a hybrid cache")
def test_past_key_values_format(self):
r"""
Zamba2's cache shape depends on whether a given layer is mamba or attention.
For mamba layers, the KV cache has shape is empty and has shape [batch_size, 0].
The shape checks of this test assume instead that every layer has an attention cache, so we skip it.
"""
pass
Overwritting to pass the expected cache shapes (Zamba2 has cache shape = [batch_size, 0] for mamba layers)
"""
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, seq_length = inputs["input_ids"].shape
per_head_embed_dim = config.attention_head_dim # note: this one is not a common attribute name
self_attention_cache_shape = (batch_size, config.num_key_value_heads, seq_length, per_head_embed_dim)
# build the full cache shapes, including mamba layers
all_cache_shapes = []
for i in range(config.num_hidden_layers):
if config.layers_block_type[i] == "mamba":
all_cache_shapes.append([torch.Size([batch_size, 0]), torch.Size([batch_size, 0])])
else:
all_cache_shapes.append([self_attention_cache_shape, self_attention_cache_shape])
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
@unittest.skip(reason="Zamba2 has hybrid cache.")
def test_generate_continue_from_inputs_embeds(self):