[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note
This commit is contained in:
@@ -785,7 +785,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
|
||||
)
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||
block_len = getattr(self.model_tester, "block_len", None)
|
||||
encoder_expected_shape = (batch_size, 2, config.num_attention_heads, block_len, 3 * block_len)
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
@@ -920,10 +920,10 @@ class LongT5TGlobalModelTest(LongT5ModelTest):
|
||||
[self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
|
||||
)
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||
block_len = getattr(self.model_tester, "block_len", None)
|
||||
global_block_size = getattr(self.model_tester, "global_block_size", None)
|
||||
global_seq_length = seq_length // global_block_size
|
||||
global_seq_length = prompt_length // global_block_size
|
||||
encoder_expected_shape = (
|
||||
batch_size,
|
||||
2,
|
||||
|
||||
Reference in New Issue
Block a user