[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)

* shape checks compatible with static cache

* add test

* tmp

* manually turn on eager attn when we want to output attn

* typo

* generalize to encoder-decoder models

* force compilation on cpu

* tmp commit

* fix static cache shape checks

* models with odd caches

* fix copies

* shorter cache search loop

* use decoder_past_key_values everywhere

* better test variable names and comments

* signature

* rename _check_outputs into _check_generate_outputs

* add comments

* HybridCache future test note
This commit is contained in:
Joao Gante
2025-02-10 17:50:54 +00:00
committed by GitHub
parent 9510ae39d9
commit be2ac0916a
25 changed files with 379 additions and 917 deletions

View File

@@ -620,36 +620,42 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
self.assertIsNotNone(model)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
):
# NOTE (joao): this function is substancially different from the original, the attention has different
# *number* of shapes in certain conditions
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
self.assertEqual(len(attentions), (output_length - prompt_length))
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % config.local_attn_chunk_length != 0)
tgt_chunk_len = config.local_attn_chunk_length
src_chunk_len = config.local_attn_chunk_length * (
for generated_length, iter_attentions in enumerate(attentions):
use_cache = decoder_past_key_values is not None and generated_length > 0
model_input_length = prompt_length + generated_length if not use_cache else 1
num_chunks = model_input_length // config.local_attn_chunk_length + (
model_input_length % config.local_attn_chunk_length != 0
)
model_input_chunk_len = config.local_attn_chunk_length
query_chunk_len = config.local_attn_chunk_length * (
1 + config.local_num_chunks_after + config.local_num_chunks_before
)
if use_cache:
expected_shape = (
batch_size * num_beam_groups,
batch_size,
config.num_attention_heads,
tgt_len,
min_length // config.local_attn_chunk_length + 1 + idx,
model_input_length,
prompt_length // config.local_attn_chunk_length + generated_length,
)
else:
expected_shape = (
batch_size * num_beam_groups,
batch_size,
config.num_attention_heads,
num_chunks,
tgt_chunk_len,
src_chunk_len,
model_input_chunk_len,
query_chunk_len,
)
# check attn size
self.assertListEqual(
@@ -657,25 +663,29 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
):
# NOTE (joao): this function is substancially different from the original, the hidden states have different
# length in certain conditions
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
self.assertEqual(len(hidden_states), (output_length - prompt_length))
for idx, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length + idx
seq_len = config.local_attn_chunk_length * (
seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0)
for generation_length, iter_hidden_states in enumerate(hidden_states):
use_cache_this_iter = use_cache and generation_length > 0
model_input_length = prompt_length + generation_length
model_output_length = config.local_attn_chunk_length * (
model_input_length // config.local_attn_chunk_length
+ (model_input_length % config.local_attn_chunk_length != 0)
)
if use_cache:
seq_len = 1
if use_cache_this_iter:
model_output_length = 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
expected_shape = (batch_size, model_output_length, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
@@ -789,37 +799,42 @@ class ReformerLSHAttnModelTest(
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
):
# NOTE (joao): this function is substancially different from the original, the attention has different
# *number* of shapes in certain conditions
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
self.assertEqual(len(attentions), (output_length - prompt_length))
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0)
tgt_chunk_len = config.lsh_attn_chunk_length
src_chunk_len = config.lsh_attn_chunk_length * (
for generated_length, iter_attentions in enumerate(attentions):
use_cache = decoder_past_key_values is not None and generated_length > 0
model_input_len = prompt_length + generated_length if not use_cache else 1
num_chunks = model_input_len // config.lsh_attn_chunk_length + (
model_input_len % config.lsh_attn_chunk_length != 0
)
model_input_chunk_len = config.lsh_attn_chunk_length
query_chunk_len = config.lsh_attn_chunk_length * (
1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before
)
if use_cache:
expected_shape = (
batch_size * num_beam_groups,
batch_size,
config.num_attention_heads,
config.num_hashes,
tgt_len,
model_input_len,
config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before),
)
else:
expected_shape = (
batch_size * num_beam_groups,
batch_size,
config.num_attention_heads,
num_chunks * config.num_hashes,
tgt_chunk_len,
src_chunk_len,
model_input_chunk_len,
query_chunk_len,
)
# check attn size
self.assertListEqual(
@@ -827,25 +842,29 @@ class ReformerLSHAttnModelTest(
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
):
# NOTE (joao): this function is substancially different from the original, the hidden states have different
# length in certain conditions
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
self.assertEqual(len(hidden_states), (output_length - prompt_length))
for idx, iter_hidden_states in enumerate(hidden_states):
seq_len = min_length + idx if not use_cache else 1
seq_len = config.lsh_attn_chunk_length * (
seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0)
for generation_length, iter_hidden_states in enumerate(hidden_states):
use_cache_this_iter = use_cache and generation_length > 0
model_input_length = prompt_length + generation_length
model_output_length = config.local_attn_chunk_length * (
model_input_length // config.local_attn_chunk_length
+ (model_input_length % config.local_attn_chunk_length != 0)
)
if use_cache:
seq_len = 1
if use_cache_this_iter:
model_output_length = 1
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
expected_shape = (batch_size, model_output_length, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],