[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)

* shape checks compatible with static cache

* add test

* tmp

* manually turn on eager attn when we want to output attn

* typo

* generalize to encoder-decoder models

* force compilation on cpu

* tmp commit

* fix static cache shape checks

* models with odd caches

* fix copies

* shorter cache search loop

* use decoder_past_key_values everywhere

* better test variable names and comments

* signature

* rename _check_outputs into _check_generate_outputs

* add comments

* HybridCache future test note
This commit is contained in:
Joao Gante
2025-02-10 17:50:54 +00:00
committed by GitHub
parent 9510ae39d9
commit be2ac0916a
25 changed files with 379 additions and 917 deletions

View File

@@ -581,103 +581,12 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
self.assertIsNotNone(model)
# overwrite because InstructBLIPVideo internally calls LM.generate() with embeds thus it cannot operate in no cache format
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
use_cache = True # force this to be True in case False is passed
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
internal_batch_size = (
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
super()._check_generate_outputs(
output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams
)
seq_length = getattr(self.model_tester, "seq_length", None)
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
config = config.text_config if hasattr(config, "text_config") else config
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# in some models we subsample the sequence length in inner layers
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
# scores
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
# unprocessed logits
self._check_logits(internal_batch_size, output.logits, config=config)
# Attentions
if self.has_attentions:
if config.is_encoder_decoder:
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, input_batch_size, config, seq_length
)
# decoder
self._check_attentions_for_generate(
internal_batch_size,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
else:
# if use_cache first input is equal to no use_cache, so skip here
attentions = output.attentions if not use_cache else output.attentions[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_attentions_for_generate(
internal_batch_size,
attentions=attentions,
min_length=min_length,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
if config.is_encoder_decoder:
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, input_batch_size, config, seq_length
)
# decoder
self._check_hidden_states_for_generate(
internal_batch_size,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
else:
# if use_cache first input is equal to no use_cache, so skip here
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
min_length = seq_length if not use_cache else seq_length + 1
self._check_hidden_states_for_generate(
internal_batch_size,
hidden_states,
min_length=min_length,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Past Key Value States
if use_cache:
past_key_values = output.past_key_values
past_sequence_length = output.sequences.shape[-1] - 1
self._check_past_key_values_for_generate(
internal_batch_size,
past_key_values,
seq_length=past_sequence_length,
config=config,
)
# overwrite because InstructBLIPVideo cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present
@pytest.mark.generate
def test_left_padding_compatibility(self):