[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)

* shape checks compatible with static cache

* add test

* tmp

* manually turn on eager attn when we want to output attn

* typo

* generalize to encoder-decoder models

* force compilation on cpu

* tmp commit

* fix static cache shape checks

* models with odd caches

* fix copies

* shorter cache search loop

* use decoder_past_key_values everywhere

* better test variable names and comments

* signature

* rename _check_outputs into _check_generate_outputs

* add comments

* HybridCache future test note
This commit is contained in:
Joao Gante
2025-02-10 17:50:54 +00:00
committed by GitHub
parent 9510ae39d9
commit be2ac0916a
25 changed files with 379 additions and 917 deletions

View File

@@ -20,7 +20,7 @@ from packaging import version
from parameterized import parameterized
from pytest import mark
from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, HybridCache, is_torch_available, pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_torch_available, pipeline
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
require_flash_attn,
@@ -135,51 +135,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
def test_generate_continue_from_inputs_embeds(self):
pass
# overwrite because HybridCache has fixed length for key/values
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
src_len = min_length + idx if not use_cache else max_length
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
# overwrite because HybridCache has fixed length for key/values
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
self.assertIsInstance(past_key_values, HybridCache)
# check shape key, value (batch, head, max_seq_length, head_features)
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
num_hidden_layers = config.num_hidden_layers
# we should get `max_length` in shape, not `max_length - embeds_length`
# `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
def test_sdpa_equivalence(self):
pass