[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)

* shape checks compatible with static cache

* add test

* tmp

* manually turn on eager attn when we want to output attn

* typo

* generalize to encoder-decoder models

* force compilation on cpu

* tmp commit

* fix static cache shape checks

* models with odd caches

* fix copies

* shorter cache search loop

* use decoder_past_key_values everywhere

* better test variable names and comments

* signature

* rename _check_outputs into _check_generate_outputs

* add comments

* HybridCache future test note
This commit is contained in:
Joao Gante
2025-02-10 17:50:54 +00:00
committed by GitHub
parent 9510ae39d9
commit be2ac0916a
25 changed files with 379 additions and 917 deletions

View File

@@ -636,57 +636,52 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
weight.data.fill_(3)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
self.assertEqual(len(hidden_states), (output_length - prompt_length))
for idx, iter_hidden_states in enumerate(hidden_states):
for generated_length, iter_hidden_states in enumerate(hidden_states):
# check hidden size
for i, layer_hidden_states in enumerate(iter_hidden_states):
# every 2nd tensor is from extra stream
if i % 2 != 0:
seq_len = 1
model_output_length = 1
else:
# for first item dummy PAD token is appended so need one more
# else offset+dummy_token when using cache
seq_len = (min_length + 1) if idx == 0 else 3
model_output_length = (prompt_length + 1) if generated_length == 0 else 3
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
expected_shape = (batch_size, model_output_length, config.hidden_size)
self.assertEqual(layer_hidden_states.shape, expected_shape)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
self.assertEqual(len(attentions), (output_length - prompt_length))
for idx, attentions_item in enumerate(attentions):
for generated_length, attentions_item in enumerate(attentions):
for iter_attentions in attentions_item:
tgt_len = min_length
model_input_length = prompt_length
# for first item dummy PAD token is appended so need one more
# every token after consists of offset+dummy_token length when using cache
if idx == 0:
tgt_len += 1
if generated_length == 0:
model_input_length += 1
else:
tgt_len = 3
model_input_length = 3
src_len = min_length + idx + 1
query_length = prompt_length + generated_length + 1
expected_shape = (
batch_size * num_beam_groups,
config.num_attention_heads,
tgt_len,
src_len,
)
expected_shape = (batch_size, config.num_attention_heads, model_input_length, query_length)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions],