[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)

* shape checks compatible with static cache

* add test

* tmp

* manually turn on eager attn when we want to output attn

* typo

* generalize to encoder-decoder models

* force compilation on cpu

* tmp commit

* fix static cache shape checks

* models with odd caches

* fix copies

* shorter cache search loop

* use decoder_past_key_values everywhere

* better test variable names and comments

* signature

* rename _check_outputs into _check_generate_outputs

* add comments

* HybridCache future test note
This commit is contained in:
Joao Gante
2025-02-10 17:50:54 +00:00
committed by GitHub
parent 9510ae39d9
commit be2ac0916a
25 changed files with 379 additions and 917 deletions

View File

@@ -323,32 +323,37 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
torch.testing.assert_close(out_embeds, out_ids)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
):
# Mllama has cross attention layers and those have a different shape than normal attention layers
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
self.assertEqual(len(attentions), (output_length - prompt_length))
cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
use_cache = decoder_past_key_values is not None
for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
src_len = min_length + idx
for generated_length, iter_attentions in enumerate(attentions):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
query_length = prompt_length + generated_length
expected_shape = (
batch_size * num_beam_groups,
batch_size,
config.num_attention_heads,
tgt_len,
src_len,
model_input_length,
query_length,
)
expected_shape_cross = (
batch_size * num_beam_groups,
batch_size,
config.num_attention_heads,
tgt_len,
model_input_length,
self.model_tester.image_length,
)