[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note
This commit is contained in:
@@ -636,57 +636,52 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
weight.data.fill_(3)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||
):
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||
|
||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||
for generated_length, iter_hidden_states in enumerate(hidden_states):
|
||||
# check hidden size
|
||||
for i, layer_hidden_states in enumerate(iter_hidden_states):
|
||||
# every 2nd tensor is from extra stream
|
||||
if i % 2 != 0:
|
||||
seq_len = 1
|
||||
model_output_length = 1
|
||||
else:
|
||||
# for first item dummy PAD token is appended so need one more
|
||||
# else offset+dummy_token when using cache
|
||||
seq_len = (min_length + 1) if idx == 0 else 3
|
||||
model_output_length = (prompt_length + 1) if generated_length == 0 else 3
|
||||
|
||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||
expected_shape = (batch_size, model_output_length, config.hidden_size)
|
||||
self.assertEqual(layer_hidden_states.shape, expected_shape)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||
):
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||
|
||||
for idx, attentions_item in enumerate(attentions):
|
||||
for generated_length, attentions_item in enumerate(attentions):
|
||||
for iter_attentions in attentions_item:
|
||||
tgt_len = min_length
|
||||
model_input_length = prompt_length
|
||||
|
||||
# for first item dummy PAD token is appended so need one more
|
||||
# every token after consists of offset+dummy_token length when using cache
|
||||
if idx == 0:
|
||||
tgt_len += 1
|
||||
if generated_length == 0:
|
||||
model_input_length += 1
|
||||
else:
|
||||
tgt_len = 3
|
||||
model_input_length = 3
|
||||
|
||||
src_len = min_length + idx + 1
|
||||
query_length = prompt_length + generated_length + 1
|
||||
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
)
|
||||
expected_shape = (batch_size, config.num_attention_heads, model_input_length, query_length)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
[layer_attention.shape for layer_attention in iter_attentions],
|
||||
|
||||
Reference in New Issue
Block a user