[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache
* add test
* tmp
* manually turn on eager attn when we want to output attn
* typo
* generalize to encoder-decoder models
* force compilation on cpu
* tmp commit
* fix static cache shape checks
* models with odd caches
* fix copies
* shorter cache search loop
* use decoder_past_key_values everywhere
* better test variable names and comments
* signature
* rename _check_outputs into _check_generate_outputs
* add comments
* HybridCache future test note
This commit is contained in:
@@ -399,11 +399,11 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
],
|
||||
)
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||
encoder_expected_shape = (
|
||||
batch_size,
|
||||
config.num_attention_heads,
|
||||
math.ceil(seq_length / config.block_size),
|
||||
math.ceil(prompt_length / config.block_size),
|
||||
config.block_size,
|
||||
config.block_size + config.num_global_tokens,
|
||||
)
|
||||
@@ -413,8 +413,8 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
[encoder_expected_shape] * len(attentions),
|
||||
)
|
||||
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
||||
encoder_expected_shape = (batch_size, self.round_up(seq_length, config.block_size), config.hidden_size)
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
|
||||
encoder_expected_shape = (batch_size, self.round_up(prompt_length, config.block_size), config.hidden_size)
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
# Only the last layer will have the hidden states truncated back to token level
|
||||
self.assertListEqual(
|
||||
@@ -424,7 +424,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
# Only the last layer will have the hidden states truncated back to token level
|
||||
self.assertEqual(
|
||||
hidden_states[-1][0].shape,
|
||||
(batch_size, seq_length, config.hidden_size),
|
||||
(batch_size, prompt_length, config.hidden_size),
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
|
||||
Reference in New Issue
Block a user