[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)

* shape checks compatible with static cache

* add test

* tmp

* manually turn on eager attention when we want to output attention weights

* typo

* generalize to encoder-decoder models

* force compilation on cpu

* tmp commit

* fix static cache shape checks

* models with odd caches

* fix copies

* shorter cache search loop

* use decoder_past_key_values everywhere

* better test variable names and comments

* signature

* rename _check_outputs into _check_generate_outputs

* add comments

* HybridCache future test note
This commit is contained in:
Joao Gante
2025-02-10 17:50:54 +00:00
committed by GitHub
parent 9510ae39d9
commit be2ac0916a
25 changed files with 379 additions and 917 deletions

View File

@@ -399,11 +399,11 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
],
)
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
encoder_expected_shape = (
batch_size,
config.num_attention_heads,
math.ceil(seq_length / config.block_size),
math.ceil(prompt_length / config.block_size),
config.block_size,
config.block_size + config.num_global_tokens,
)
@@ -413,8 +413,8 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
[encoder_expected_shape] * len(attentions),
)
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
encoder_expected_shape = (batch_size, self.round_up(seq_length, config.block_size), config.hidden_size)
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
encoder_expected_shape = (batch_size, self.round_up(prompt_length, config.block_size), config.hidden_size)
self.assertIsInstance(hidden_states, tuple)
# Only the last layer will have the hidden states truncated back to token level
self.assertListEqual(
@@ -424,7 +424,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
# Only the last layer will have the hidden states truncated back to token level
self.assertEqual(
hidden_states[-1][0].shape,
(batch_size, seq_length, config.hidden_size),
(batch_size, prompt_length, config.hidden_size),
)
def test_hidden_states_output(self):