@@ -1539,92 +1539,133 @@ class GenerationTesterMixin:
|
||||
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
||||
|
||||
@pytest.mark.generate
def test_past_key_values_format(self, custom_all_cache_shapes=None):
    """
    Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the
    expected cache shapes.
    Having a standard KV cache format is important for a consistent API (and for advanced generation methods).

    Args:
        custom_all_cache_shapes (`list`, *optional*):
            Expected cache shapes, indexed as `custom_all_cache_shapes[layer_idx][entry_idx]`. Entries 0 and 1 are
            compared against the self-attention key and value cache shapes. For encoder-decoder models, entries 2
            and 3 are compared against the cross-attention key and value cache shapes *after* the 3rd dimension
            has been dropped (i.e. they must have 3 dimensions). When `None`, default shapes are derived from the
            model config and the test inputs.
    """
    for model_class in self.all_generative_model_classes:
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        # 1. If it doesn't support cache, skip the test
        if not hasattr(config.get_text_config(), "use_cache"):
            self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

        model = model_class(config).to(torch_device)
        model = model.eval()
        if "use_cache" not in inputs:
            inputs["use_cache"] = True
        outputs = model(**inputs)

        # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
        if "past_key_values" not in outputs:
            self.skipTest(reason="This model doesn't return `past_key_values`")

        # 2. retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided)
        past_kv = outputs["past_key_values"]
        # Anything that is not a `Cache` instance is handled as the legacy format, indexed as `past_kv[layer][entry]`
        is_legacy_cache = not isinstance(past_kv, Cache)

        text_config = config.get_text_config()
        # Configs name the decoder depth differently: take the first attribute that is set
        num_decoder_layers = (
            getattr(text_config, "decoder_layers", None)
            or getattr(text_config, "num_decoder_layers", None)
            or text_config.num_hidden_layers
        )

        if custom_all_cache_shapes is None:
            num_query_attention_heads = getattr(
                text_config, "decoder_attention_heads", text_config.num_attention_heads
            )
            embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
            # The per-head dim is derived from the number of *query* heads ...
            per_head_embed_dim = embed_dim // num_query_attention_heads
            # ... while the cached K/V use `num_key_value_heads` when the config defines it (some models have a
            # different number of heads for query vs key/value)
            num_key_value_heads = (
                text_config.num_key_value_heads
                if getattr(text_config, "num_key_value_heads", None) is not None
                else num_query_attention_heads
            )
            if config.is_encoder_decoder:
                encoder_num_attention_heads = (
                    text_config.encoder_attention_heads
                    if hasattr(text_config, "encoder_attention_heads")
                    else text_config.num_attention_heads
                )
                encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
                batch_size, seq_length = inputs["decoder_input_ids"].shape
                # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
                # autoregressive generation, we're keeping the test general and not checking the 3rd dim
                default_cross_attention_shape = (
                    batch_size,
                    encoder_num_attention_heads,
                    encoder_per_head_embed_dim,
                )
                default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
                all_cache_shapes = [
                    [
                        default_self_attention_shape,
                        default_self_attention_shape,
                        default_cross_attention_shape,
                        default_cross_attention_shape,
                    ]
                    for _ in range(num_decoder_layers)
                ]
            else:
                batch_size, seq_length = inputs["input_ids"].shape
                default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
                all_cache_shapes = [
                    [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
                ]

        else:
            all_cache_shapes = custom_all_cache_shapes

        # 3. Check cache shapes
        # 3.1. Encoder-Decoder checks
        if config.is_encoder_decoder:
            num_cache_decoder_layers = (
                len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
            )
            self.assertEqual(num_cache_decoder_layers, num_decoder_layers)

            for i in range(num_decoder_layers):
                if is_legacy_cache:
                    # legacy check: confirm number of elements in the tuple of *this* layer
                    # (K V for the decoder + K V for the encoder = 4)
                    self.assertEqual(len(past_kv[i]), 4)

                # Self attention
                self_attention_layer_key_cache = (
                    past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
                )
                self_attention_layer_value_cache = (
                    past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
                )
                self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
                self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])

                # Cross attention (ignore 3rd dim, see default shape preparation)
                cross_attention_layer_key_cache = (
                    past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
                )
                cross_attention_layer_value_cache = (
                    past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
                )
                # Indexing position 0 of the 3rd dim drops that dim, leaving a 3-dim shape to compare
                cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
                cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
                self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
                self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])

        # 3.2. Decoder-only checks
        else:
            num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
            self.assertEqual(num_cache_decoder_layers, num_decoder_layers)

            for i in range(num_decoder_layers):
                if is_legacy_cache:
                    # legacy check: confirm number of elements in the tuple of *this* layer
                    # (K V for the decoder = 2)
                    self.assertEqual(len(past_kv[i]), 2)

                # Self attention
                self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
                self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
                self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
                self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
|
||||
Reference in New Issue
Block a user