[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note
This commit is contained in:
@@ -416,48 +416,6 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.T
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
)
|
||||
|
||||
# scores
|
||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
||||
|
||||
# Attentions
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(
|
||||
output.encoder_attentions, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
# encoder
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
|
||||
# decoder
|
||||
self._check_hidden_states_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_hidden_states,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
|
||||
Reference in New Issue
Block a user