[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note
This commit is contained in:
@@ -620,36 +620,42 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||
):
|
||||
# NOTE (joao): this function is substancially different from the original, the attention has different
|
||||
# *number* of shapes in certain conditions
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % config.local_attn_chunk_length != 0)
|
||||
tgt_chunk_len = config.local_attn_chunk_length
|
||||
src_chunk_len = config.local_attn_chunk_length * (
|
||||
for generated_length, iter_attentions in enumerate(attentions):
|
||||
use_cache = decoder_past_key_values is not None and generated_length > 0
|
||||
|
||||
model_input_length = prompt_length + generated_length if not use_cache else 1
|
||||
num_chunks = model_input_length // config.local_attn_chunk_length + (
|
||||
model_input_length % config.local_attn_chunk_length != 0
|
||||
)
|
||||
model_input_chunk_len = config.local_attn_chunk_length
|
||||
query_chunk_len = config.local_attn_chunk_length * (
|
||||
1 + config.local_num_chunks_after + config.local_num_chunks_before
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
batch_size,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
min_length // config.local_attn_chunk_length + 1 + idx,
|
||||
model_input_length,
|
||||
prompt_length // config.local_attn_chunk_length + generated_length,
|
||||
)
|
||||
else:
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
batch_size,
|
||||
config.num_attention_heads,
|
||||
num_chunks,
|
||||
tgt_chunk_len,
|
||||
src_chunk_len,
|
||||
model_input_chunk_len,
|
||||
query_chunk_len,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
@@ -657,25 +663,29 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
||||
)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||
):
|
||||
# NOTE (joao): this function is substancially different from the original, the hidden states have different
|
||||
# length in certain conditions
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||
|
||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||
seq_len = min_length + idx
|
||||
seq_len = config.local_attn_chunk_length * (
|
||||
seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0)
|
||||
for generation_length, iter_hidden_states in enumerate(hidden_states):
|
||||
use_cache_this_iter = use_cache and generation_length > 0
|
||||
model_input_length = prompt_length + generation_length
|
||||
model_output_length = config.local_attn_chunk_length * (
|
||||
model_input_length // config.local_attn_chunk_length
|
||||
+ (model_input_length % config.local_attn_chunk_length != 0)
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
seq_len = 1
|
||||
if use_cache_this_iter:
|
||||
model_output_length = 1
|
||||
|
||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||
expected_shape = (batch_size, model_output_length, config.hidden_size)
|
||||
# check hidden size
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||
@@ -789,37 +799,42 @@ class ReformerLSHAttnModelTest(
|
||||
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||
):
|
||||
# NOTE (joao): this function is substancially different from the original, the attention has different
|
||||
# *number* of shapes in certain conditions
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0)
|
||||
tgt_chunk_len = config.lsh_attn_chunk_length
|
||||
src_chunk_len = config.lsh_attn_chunk_length * (
|
||||
for generated_length, iter_attentions in enumerate(attentions):
|
||||
use_cache = decoder_past_key_values is not None and generated_length > 0
|
||||
model_input_len = prompt_length + generated_length if not use_cache else 1
|
||||
num_chunks = model_input_len // config.lsh_attn_chunk_length + (
|
||||
model_input_len % config.lsh_attn_chunk_length != 0
|
||||
)
|
||||
model_input_chunk_len = config.lsh_attn_chunk_length
|
||||
query_chunk_len = config.lsh_attn_chunk_length * (
|
||||
1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
batch_size,
|
||||
config.num_attention_heads,
|
||||
config.num_hashes,
|
||||
tgt_len,
|
||||
model_input_len,
|
||||
config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before),
|
||||
)
|
||||
else:
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
batch_size,
|
||||
config.num_attention_heads,
|
||||
num_chunks * config.num_hashes,
|
||||
tgt_chunk_len,
|
||||
src_chunk_len,
|
||||
model_input_chunk_len,
|
||||
query_chunk_len,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
@@ -827,25 +842,29 @@ class ReformerLSHAttnModelTest(
|
||||
)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||
):
|
||||
# NOTE (joao): this function is substancially different from the original, the hidden states have different
|
||||
# length in certain conditions
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
||||
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||
|
||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
||||
seq_len = min_length + idx if not use_cache else 1
|
||||
seq_len = config.lsh_attn_chunk_length * (
|
||||
seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0)
|
||||
for generation_length, iter_hidden_states in enumerate(hidden_states):
|
||||
use_cache_this_iter = use_cache and generation_length > 0
|
||||
model_input_length = prompt_length + generation_length
|
||||
model_output_length = config.local_attn_chunk_length * (
|
||||
model_input_length // config.local_attn_chunk_length
|
||||
+ (model_input_length % config.local_attn_chunk_length != 0)
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
seq_len = 1
|
||||
if use_cache_this_iter:
|
||||
model_output_length = 1
|
||||
|
||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||
expected_shape = (batch_size, model_output_length, config.hidden_size)
|
||||
# check hidden size
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||
|
||||
Reference in New Issue
Block a user