Generate: add test to check KV format (#23403)
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -1645,6 +1645,77 @@ class GenerationTesterMixin:
|
||||
|
||||
self.assertTrue(no_failures)
|
||||
|
||||
def test_past_key_values_format(self):
|
||||
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
|
||||
# standard KV cache format is important for a consistent API (and for advanced generation methods).
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# If it doesn't support cache, pass the test
|
||||
if not hasattr(config, "use_cache"):
|
||||
return
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
if "use_cache" not in inputs:
|
||||
inputs["use_cache"] = True
|
||||
outputs = model(**inputs)
|
||||
|
||||
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
|
||||
if "past_key_values" not in outputs:
|
||||
return
|
||||
|
||||
num_hidden_layers = (
|
||||
getattr(config, "decoder_layers", None)
|
||||
or getattr(config, "num_decoder_layers", None)
|
||||
or config.num_hidden_layers
|
||||
)
|
||||
num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads)
|
||||
embed_dim = getattr(config, "d_model", config.hidden_size)
|
||||
per_head_embed_dim = embed_dim // num_attention_heads
|
||||
|
||||
past_kv = outputs["past_key_values"]
|
||||
self.assertEqual(len(past_kv), num_hidden_layers)
|
||||
|
||||
# Encoder-Decoder checks
|
||||
if config.is_encoder_decoder:
|
||||
encoder_num_attention_heads = config.encoder_attention_heads
|
||||
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
|
||||
batch_size, seq_length = inputs["decoder_input_ids"].shape
|
||||
for i in range(num_hidden_layers):
|
||||
self.assertEqual(len(past_kv[i]), 4) # K V for the decoder + K V for the encoder = 4
|
||||
self.assertEqual(
|
||||
past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
self.assertEqual(
|
||||
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
# The sequence length for the encoder K V depends on the model. Since it is not manipulated in
|
||||
# autoregressive generation, I'm keeping the test general and not checking the 3rd dim
|
||||
self.assertEqual(
|
||||
(past_kv[i][2].shape[0], past_kv[i][2].shape[1], past_kv[i][2].shape[3]),
|
||||
(batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
|
||||
)
|
||||
self.assertEqual(
|
||||
(past_kv[i][3].shape[0], past_kv[i][3].shape[1], past_kv[i][3].shape[3]),
|
||||
(batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
|
||||
)
|
||||
|
||||
# Decoder-only checks
|
||||
else:
|
||||
# TODO: this line is only needed because of imagegpt, where "pixel_values" = "input_ids". Fix the
|
||||
# tests in imagegpt such that `prepare_config_and_inputs_for_common` returns the later (and the other
|
||||
# tests use it)
|
||||
key = "input_ids" if "input_ids" in inputs else "pixel_values"
|
||||
batch_size, seq_length = inputs[key].shape
|
||||
for i in range(num_hidden_layers):
|
||||
self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2
|
||||
self.assertEqual(
|
||||
past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
self.assertEqual(
|
||||
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
|
||||
@@ -393,6 +393,10 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs)
|
||||
|
||||
@unittest.skip("Bloom has a non-standard KV cache format.")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
@@ -475,6 +475,10 @@ class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
def test_disk_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("BigCodeGPT has a non-standard KV cache format.")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
def test_gpt_bigcode_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs)
|
||||
|
||||
@@ -831,8 +831,12 @@ class ReformerLSHAttnModelTest(
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
|
||||
@unittest.skip("Fails because the sequence length is not a multiple of 4")
|
||||
def test_problem_types(self):
|
||||
# Fails because the sequence length is not a multiple of 4
|
||||
pass
|
||||
|
||||
@unittest.skip("Fails because the sequence length is not a multiple of 4")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user