[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note
This commit is contained in:
@@ -753,20 +753,20 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
||||
text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||
# overwrite because # pix2struct seq length depends on image inputs
|
||||
seq_length = self.model_tester.max_patches
|
||||
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
|
||||
prompt_length = self.model_tester.max_patches
|
||||
encoder_expected_shape = (batch_size, config.num_attention_heads, prompt_length, prompt_length)
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_attentions.shape for layer_attentions in attentions],
|
||||
[encoder_expected_shape] * len(attentions),
|
||||
)
|
||||
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
|
||||
# overwrite because # pix2struct seq length depends on image inputs
|
||||
seq_length = self.model_tester.max_patches
|
||||
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
||||
prompt_length = self.model_tester.max_patches
|
||||
encoder_expected_shape = (batch_size, prompt_length, config.hidden_size)
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||
|
||||
Reference in New Issue
Block a user