[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note
This commit is contained in:
@@ -20,7 +20,7 @@ from packaging import version
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
@@ -150,51 +150,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# overwrite because HybridCache has fixed length for key/values
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
src_len = min_length + idx if not use_cache else max_length
|
||||
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||
)
|
||||
|
||||
# overwrite because HybridCache has fixed length for key/values
|
||||
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
||||
self.assertIsInstance(past_key_values, HybridCache)
|
||||
|
||||
# check shape key, value (batch, head, max_seq_length, head_features)
|
||||
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
num_key_value_heads = (
|
||||
config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
||||
# `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
|
||||
static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
|
||||
static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
|
||||
self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
|
||||
|
||||
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
|
||||
def test_sdpa_equivalence(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user