[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache
* add test
* tmp
* manually turn on eager attn when we want to output attn
* typo
* generalize to encoder-decoder models
* force compilation on cpu
* tmp commit
* fix static cache shape checks
* models with odd caches
* fix copies
* shorter cache search loop
* use decoder_past_key_values everywhere
* better test variable names and comments
* signature
* rename _check_outputs into _check_generate_outputs
* add comments
* HybridCache future test note
This commit is contained in:
@@ -323,32 +323,37 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||
):
|
||||
# Mllama has cross attention layers and those have a different shape than normal attention layers
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||
|
||||
cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
|
||||
use_cache = decoder_past_key_values is not None
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
src_len = min_length + idx
|
||||
for generated_length, iter_attentions in enumerate(attentions):
|
||||
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||
if use_cache and generated_length > 0:
|
||||
model_input_length = 1
|
||||
else:
|
||||
model_input_length = prompt_length + generated_length
|
||||
query_length = prompt_length + generated_length
|
||||
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
batch_size,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
model_input_length,
|
||||
query_length,
|
||||
)
|
||||
|
||||
expected_shape_cross = (
|
||||
batch_size * num_beam_groups,
|
||||
batch_size,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
model_input_length,
|
||||
self.model_tester.image_length,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user