[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)

* shape checks compatible with static cache

* add test

* tmp

* manually turn on eager attn when we want to output attn

* typo

* generalize to encoder-decoder models

* force compilation on cpu

* tmp commit

* fix static cache shape checks

* models with odd caches

* fix copies

* shorter cache search loop

* use decoder_past_key_values everywhere

* better test variable names and comments

* signature

* rename _check_outputs into _check_generate_outputs

* add comments

* HybridCache future test note
This commit is contained in:
Joao Gante
2025-02-10 17:50:54 +00:00
committed by GitHub
parent 9510ae39d9
commit be2ac0916a
25 changed files with 379 additions and 917 deletions

View File

@@ -527,48 +527,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
def test_generate_without_input_ids(self):
pass
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, mel, seq_length = input_ids.shape
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# scores
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
# Attentions
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, batch_size, config, subsampled_seq_length
)
# decoder
self._check_attentions_for_generate(
num_sequences_in_output,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
)
# decoder
self._check_hidden_states_for_generate(
num_sequences_in_output,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
# `input_features`
def test_lm_head_model_random_no_beam_search_generate(self):

View File

@@ -1607,6 +1607,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_generate_compile_model_forward(self):
pass
# TODO (joao, eustache): fix me :)
@unittest.skip(reason="A CUDA exception is thrown when storing extra outputs")
def test_generate_compilation_all_outputs(self):
pass
@require_torch
@require_torchaudio