@@ -429,9 +429,23 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
||||
|
||||
@unittest.skip(reason="Deepseek-V3 uses MLA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
"""
|
||||
Overwritting to pass the expected cache shapes (Deepseek-V3 uses MLA so the cache shapes are non-standard)
|
||||
"""
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
# difference: last dim
|
||||
k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
v_embed_dim = config.v_head_dim
|
||||
self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
|
||||
self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
|
||||
# build the full cache shapes
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
all_cache_shapes = [
|
||||
[self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
|
||||
]
|
||||
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
|
||||
@@ -264,51 +264,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_past_key_values_format(self):
|
||||
# Falcon can have different numbers of KV-heads than the number of query heads, so we need
|
||||
# to override this test to use the right head counts.
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# If it doesn't support cache, pass the test
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest(reason="Model does not support cache")
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
if "use_cache" not in inputs:
|
||||
inputs["use_cache"] = True
|
||||
outputs = model(**inputs)
|
||||
|
||||
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
|
||||
if "past_key_values" not in outputs:
|
||||
self.skipTest(reason="Model does not return past_key_values")
|
||||
|
||||
num_hidden_layers = (
|
||||
getattr(config, "decoder_layers", None)
|
||||
or getattr(config, "num_decoder_layers", None)
|
||||
or config.num_hidden_layers
|
||||
)
|
||||
num_attention_heads = getattr(config, "num_kv_heads", config.num_attention_heads)
|
||||
embed_dim = getattr(config, "d_model", config.hidden_size)
|
||||
per_head_embed_dim = embed_dim // num_attention_heads
|
||||
|
||||
past_kv = outputs["past_key_values"]
|
||||
self.assertEqual(len(past_kv), num_hidden_layers)
|
||||
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
for i in range(num_hidden_layers):
|
||||
if config.new_decoder_architecture:
|
||||
num_attention_heads = config.num_attention_heads
|
||||
elif config.multi_query:
|
||||
num_attention_heads = 1
|
||||
self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2
|
||||
self.assertEqual(
|
||||
past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
self.assertEqual(
|
||||
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",)])
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon
|
||||
def test_model_rope_scaling_from_config(self, scaling_type):
|
||||
|
||||
@@ -296,10 +296,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Gemma uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
@@ -264,10 +264,6 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Glm uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@is_flaky()
|
||||
def test_custom_4d_attention_mask(self):
|
||||
"""Overwrite the common test to use atol=1e-3 instead of 1e-4. Can still rarely fail, thus flaky."""
|
||||
|
||||
@@ -222,12 +222,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
|
||||
)
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
|
||||
def test_flash_attn_2_fp32_ln(self):
|
||||
pass
|
||||
|
||||
@@ -319,6 +319,10 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
def test_left_padding_compatibility(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Model inputs don't fit test pattern") # and it's not used enough to be worth fixing :)
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
||||
@@ -251,10 +251,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
@unittest.skip(reason="JetMoe uses MoA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
@@ -292,10 +292,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Mistral uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
@@ -324,10 +324,6 @@ class TFMistralModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
@unittest.skip("Mistral uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Vocab resizing is not supported")
|
||||
def test_save_load_after_resize_token_embeddings(self):
|
||||
pass
|
||||
|
||||
@@ -291,10 +291,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Mixtral uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
@@ -409,7 +409,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
# overridden because mllama has special cache for self and cross attentions
|
||||
# overridden because mllama is not an encoder-decoder model, but has encoder-decoder-like cache
|
||||
def test_past_key_values_format(self):
|
||||
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
|
||||
# standard KV cache format is important for a consistent API (and for advanced generation methods).
|
||||
|
||||
@@ -303,10 +303,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Qwen2 uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
@@ -331,10 +331,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Qwen2Moe uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
@@ -306,10 +306,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
# Ignore copy
|
||||
def test_past_key_values_format(self):
|
||||
super().test_past_key_values_format()
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
@@ -325,10 +325,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
# Ignore copy
|
||||
def test_past_key_values_format(self):
|
||||
super().test_past_key_values_format()
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
@@ -223,10 +223,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="RecurrentGemma does not return pkv")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="RecurrentGemma only supports sdpa")
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
pass
|
||||
|
||||
@@ -281,10 +281,6 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Starcoder2 uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
@@ -322,14 +322,22 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Zamba2 has a hybrid cache")
|
||||
def test_past_key_values_format(self):
|
||||
r"""
|
||||
Zamba2's cache shape depends on whether a given layer is mamba or attention.
|
||||
For mamba layers, the KV cache has shape is empty and has shape [batch_size, 0].
|
||||
The shape checks of this test assume instead that every layer has an attention cache, so we skip it.
|
||||
"""
|
||||
pass
|
||||
Overwritting to pass the expected cache shapes (Zamba2 has cache shape = [batch_size, 0] for mamba layers)
|
||||
"""
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
per_head_embed_dim = config.attention_head_dim # note: this one is not a common attribute name
|
||||
self_attention_cache_shape = (batch_size, config.num_key_value_heads, seq_length, per_head_embed_dim)
|
||||
# build the full cache shapes, including mamba layers
|
||||
all_cache_shapes = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
if config.layers_block_type[i] == "mamba":
|
||||
all_cache_shapes.append([torch.Size([batch_size, 0]), torch.Size([batch_size, 0])])
|
||||
else:
|
||||
all_cache_shapes.append([self_attention_cache_shape, self_attention_cache_shape])
|
||||
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
|
||||
|
||||
@unittest.skip(reason="Zamba2 has hybrid cache.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
|
||||
Reference in New Issue
Block a user