[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note
This commit is contained in:
@@ -456,51 +456,26 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||
):
|
||||
# GIT attention shape depends on image inputs, overwrite
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx + image_length if not use_cache else 1
|
||||
src_len = min_length + idx + image_length
|
||||
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||
)
|
||||
prompt_length += image_length
|
||||
output_length += image_length
|
||||
super()._check_attentions_for_generate(
|
||||
batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||
)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||
):
|
||||
# GIT attention shape depends on image inputs, overwrite
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
||||
|
||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||
seq_len = min_length + idx + image_length if not use_cache else 1
|
||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||
# check hidden size
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
prompt_length += image_length
|
||||
output_length += image_length
|
||||
super()._check_hidden_states_for_generate(
|
||||
batch_size, hidden_states, prompt_length, output_length, config, use_cache=use_cache
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
|
||||
Reference in New Issue
Block a user