[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note
This commit is contained in:
@@ -565,103 +565,12 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
# overwrite because InstructBLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format
|
||||
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||
use_cache = True # force this to be True in case False is passed
|
||||
|
||||
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
||||
internal_batch_size = (
|
||||
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
|
||||
super()._check_generate_outputs(
|
||||
output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams
|
||||
)
|
||||
|
||||
seq_length = getattr(self.model_tester, "seq_length", None)
|
||||
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
|
||||
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
|
||||
|
||||
config = config.text_config if hasattr(config, "text_config") else config
|
||||
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
)
|
||||
|
||||
# in some models we subsample the sequence length in inner layers
|
||||
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
|
||||
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
|
||||
# scores
|
||||
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
|
||||
|
||||
# unprocessed logits
|
||||
self._check_logits(internal_batch_size, output.logits, config=config)
|
||||
|
||||
# Attentions
|
||||
if self.has_attentions:
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(
|
||||
output.encoder_attentions, input_batch_size, config, seq_length
|
||||
)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
internal_batch_size,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
else:
|
||||
# if use_cache first input is equal to no use_cache, so skip here
|
||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_attentions_for_generate(
|
||||
internal_batch_size,
|
||||
attentions=attentions,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, input_batch_size, config, seq_length
|
||||
)
|
||||
|
||||
# decoder
|
||||
self._check_hidden_states_for_generate(
|
||||
internal_batch_size,
|
||||
output.decoder_hidden_states,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
else:
|
||||
# if use_cache first input is equal to no use_cache, so skip here
|
||||
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_hidden_states_for_generate(
|
||||
internal_batch_size,
|
||||
hidden_states,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Past Key Value States
|
||||
if use_cache:
|
||||
past_key_values = output.past_key_values
|
||||
past_sequence_length = output.sequences.shape[-1] - 1
|
||||
self._check_past_key_values_for_generate(
|
||||
internal_batch_size,
|
||||
past_key_values,
|
||||
seq_length=past_sequence_length,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# overwrite because InstructBLIP cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present
|
||||
@pytest.mark.generate
|
||||
def test_left_padding_compatibility(self):
|
||||
|
||||
Reference in New Issue
Block a user