From be2ac0916a7902e1683d708805270142257a254a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 10 Feb 2025 17:50:54 +0000 Subject: [PATCH] [generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993) * shape checks compatible with static cache * add test * tmp * manually turn on eager attn when we want to output attn * typo * generalize to encoder-decoder models * force compilation on cpu * tmp commit * fix static cache shape checks * models with odd caches * fix copies * shorter cache search loop * use decoder_past_key_values everywhere * better test variable names and comments * signature * rename _check_outputs into _check_generate_outputs * add comments * HybridCache future test note --- src/transformers/generation/utils.py | 44 +-- .../modeling_musicgen_melody.py | 29 -- .../qwen2_audio/modeling_qwen2_audio.py | 50 +-- tests/generation/test_utils.py | 303 ++++++++++++------ tests/models/blip_2/test_modeling_blip_2.py | 97 +----- tests/models/cohere2/test_modeling_cohere2.py | 47 +-- tests/models/gemma2/test_modeling_gemma2.py | 47 +-- tests/models/git/test_modeling_git.py | 49 +-- tests/models/idefics/test_modeling_idefics.py | 2 +- .../models/imagegpt/test_modeling_imagegpt.py | 4 +- .../test_modeling_instructblip.py | 97 +----- .../test_modeling_instructblipvideo.py | 97 +----- tests/models/led/test_modeling_led.py | 4 +- tests/models/longt5/test_modeling_longt5.py | 6 +- tests/models/mllama/test_modeling_mllama.py | 25 +- tests/models/moshi/test_modeling_moshi.py | 69 +--- .../pegasus_x/test_modeling_pegasus_x.py | 10 +- .../pix2struct/test_modeling_pix2struct.py | 12 +- .../test_modeling_recurrent_gemma.py | 25 -- .../models/reformer/test_modeling_reformer.py | 105 +++--- .../test_modeling_tf_speech_to_text.py | 42 --- .../whisper/test_modeling_tf_whisper.py | 42 --- tests/models/whisper/test_modeling_whisper.py | 5 + tests/models/xlm/test_modeling_xlm.py | 50 +-- tests/models/xlnet/test_modeling_xlnet.py | 35 +- 25 files changed, 379 insertions(+), 917 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f0f7f2b0b6..a773c4a1d9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -116,6 +116,16 @@ if is_accelerate_available(): from accelerate.hooks import AlignDevicesHook, add_hook_to_module +# Variable names used to hold the cache at generation time +ALL_CACHE_NAMES = [ + "past_key_values", # default + "cache_params", # mamba-based models + "state", # rwkv + "mems", # xlnet + "past_buckets_states", # reformer +] + + @dataclass class GenerateDecoderOnlyOutput(ModelOutput): """ @@ -756,21 +766,6 @@ class GenerationMixin: return input_ids, model_kwargs - def _extract_past_from_model_output(self, outputs: ModelOutput): - past_key_values = None - cache_name = "past_key_values" - if "past_key_values" in outputs: - past_key_values = outputs.past_key_values - elif "mems" in outputs: - past_key_values = outputs.mems - elif "past_buckets_states" in outputs: - past_key_values = outputs.past_buckets_states - elif "cache_params" in outputs: - past_key_values = outputs.cache_params - cache_name = "cache_params" - - return cache_name, past_key_values - def _update_model_kwargs_for_generation( self, outputs: ModelOutput, @@ -779,10 +774,15 @@ class GenerationMixin: num_new_tokens: int = 1, ) -> Dict[str, Any]: # update past_key_values keeping its naming used in model code - cache_name, cache = self._extract_past_from_model_output(outputs) - model_kwargs[cache_name] = cache - if getattr(outputs, "state", None) is not None: - model_kwargs["state"] = outputs.state + for possible_cache_name in ALL_CACHE_NAMES: + if possible_cache_name in outputs: + # TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated + if possible_cache_name in ("past_buckets_states", "mems"): + cache_name = "past_key_values" + else: + cache_name = possible_cache_name + model_kwargs[cache_name] = getattr(outputs, possible_cache_name) + break # update token_type_ids with last value if "token_type_ids" in model_kwargs: @@ -2087,7 +2087,7 @@ class GenerationMixin: # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. # - different models have a different cache name expected by the model (default = "past_key_values") # - `max_length`, prepared above, is used to determine the maximum cache length - max_cache_length = generation_config.max_length + max_cache_length = generation_config.max_length - 1 if ( inputs_tensor.shape[1] != input_ids_length and model_input_name == "inputs_embeds" @@ -2994,7 +2994,9 @@ class GenerationMixin: next_past_key_values = selected_outputs["past_key_values"] else: - _, next_past_key_values = self._extract_past_from_model_output(outputs) + next_past_key_values = None + for possible_cache_name in ALL_CACHE_NAMES: + next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None) # Do it in-place layer per layer to save memory if isinstance(next_past_key_values, DynamicCache) or ( isinstance(next_past_key_values, EncoderDecoderCache) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index dc0e9b882b..279a7c046c 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -2557,32 +2557,3 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): return outputs else: return output_values - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - model_inputs: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - # update past_key_values - cache_name, cache = self._extract_past_from_model_output(outputs) - model_kwargs[cache_name] = cache - - if getattr(outputs, "state", None) is not None: - model_kwargs["state"] = outputs.state - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) - - # update decoder attention mask - if "decoder_attention_mask" in model_kwargs: - decoder_attention_mask = model_kwargs["decoder_attention_mask"] - model_kwargs["decoder_attention_mask"] = torch.cat( - [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], - dim=-1, - ) - - return model_kwargs diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index a844a67861..320d209313 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -1329,54 +1329,6 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi ) return model_inputs - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - num_new_tokens: int = 1, - ) -> Dict[str, Any]: - # update past_key_values keeping its naming used in model code - cache_name, cache = self._extract_past_from_model_output(outputs) - model_kwargs[cache_name] = cache - if getattr(outputs, "state", None) is not None: - model_kwargs["state"] = outputs.state - - # update attention_mask - if getattr(outputs, "attention_mask", None) is not None: - model_kwargs["attention_mask"] = outputs.attention_mask - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) - - if not is_encoder_decoder: - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - else: - # update decoder attention mask - if "decoder_attention_mask" in model_kwargs: - decoder_attention_mask = model_kwargs["decoder_attention_mask"] - model_kwargs["decoder_attention_mask"] = torch.cat( - [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], - dim=-1, - ) - - if model_kwargs.get("use_cache", True): - model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens - else: - past_positions = model_kwargs.pop("cache_position") - new_positions = torch.arange( - past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype - ).to(past_positions.device) - model_kwargs["cache_position"] = torch.cat((past_positions, new_positions)) - return model_kwargs - def _reorder_cache(self, *args, **kwargs): return self.language_model._reorder_cache(*args, **kwargs) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c7c8c7f8c1..6833fd476e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -72,7 +72,14 @@ if is_torch_available(): SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import ( + Cache, + DynamicCache, + EncoderDecoderCache, + HybridCache, + QuantoQuantizedCache, + StaticCache, + ) from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -473,6 +480,8 @@ class GenerationTesterMixin: def test_greedy_generate_dict_outputs(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( @@ -499,12 +508,14 @@ class GenerationTesterMixin: # Retrocompatibility check self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) - self._check_outputs(output_generate, model.config) + self._check_generate_outputs(output_generate, model.config) @pytest.mark.generate def test_greedy_generate_dict_outputs_use_cache(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise if not hasattr(config, "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") @@ -531,7 +542,7 @@ class GenerationTesterMixin: output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] ) - self._check_outputs(output_generate, model.config, use_cache=True) + self._check_generate_outputs(output_generate, model.config, use_cache=True) @pytest.mark.generate def test_sample_generate(self): @@ -550,6 +561,8 @@ class GenerationTesterMixin: def test_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise model = model_class(config).to(torch_device).eval() output_generate = self._sample_generate( @@ -577,7 +590,7 @@ class GenerationTesterMixin: # Retrocompatibility check self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) - self._check_outputs(output_generate, model.config, num_return_sequences=2) + self._check_generate_outputs(output_generate, model.config, num_return_sequences=2) @pytest.mark.generate def test_beam_search_generate(self): @@ -598,6 +611,8 @@ class GenerationTesterMixin: def test_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() @@ -625,7 +640,7 @@ class GenerationTesterMixin: # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self._check_outputs( + self._check_generate_outputs( output_generate, model.config, num_return_sequences=beam_kwargs["num_return_sequences"], @@ -642,6 +657,8 @@ class GenerationTesterMixin: if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() @@ -666,7 +683,7 @@ class GenerationTesterMixin: output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] ) - self._check_outputs( + self._check_generate_outputs( output_generate, model.config, use_cache=True, @@ -721,6 +738,8 @@ class GenerationTesterMixin: def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() @@ -750,7 +769,7 @@ class GenerationTesterMixin: # Retrocompatibility check self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) - self._check_outputs( + self._check_generate_outputs( output_generate, model.config, num_return_sequences=beam_kwargs["num_return_sequences"], @@ -813,6 +832,8 @@ class GenerationTesterMixin: def test_group_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_diverse_beam_kwargs() @@ -840,7 +861,7 @@ class GenerationTesterMixin: # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self._check_outputs( + self._check_generate_outputs( output_generate, model.config, num_return_sequences=beam_kwargs["num_return_sequences"], @@ -909,6 +930,8 @@ class GenerationTesterMixin: def test_constrained_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise model = model_class(config).to(torch_device).eval() @@ -947,7 +970,7 @@ class GenerationTesterMixin: # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) - self._check_outputs( + self._check_generate_outputs( output_generate, model.config, num_return_sequences=beam_kwargs["num_return_sequences"], @@ -999,6 +1022,8 @@ class GenerationTesterMixin: if not hasattr(config, "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.is_decoder = True + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise model = model_class(config).to(torch_device).eval() output_generate = self._contrastive_generate( @@ -1019,7 +1044,7 @@ class GenerationTesterMixin: output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] ) - self._check_outputs(output_generate, model.config, use_cache=True) + self._check_generate_outputs(output_generate, model.config, use_cache=True) @pytest.mark.generate def test_contrastive_generate_low_memory(self): @@ -1205,7 +1230,7 @@ class GenerationTesterMixin: # The two outputs must match and their shape must be as expected self._check_similar_generate_outputs(output_greedy, output_assisted) for output in (output_greedy, output_assisted): - self._check_outputs(output, model.config, use_cache=True) + self._check_generate_outputs(output, model.config, use_cache=True) @pytest.mark.generate def test_prompt_lookup_decoding_matches_greedy_search(self): @@ -1270,7 +1295,7 @@ class GenerationTesterMixin: # The two outputs must match and their shape must be as expected self._check_similar_generate_outputs(output_greedy, output_prompt_lookup) for output in (output_greedy, output_prompt_lookup): - self._check_outputs(output, model.config, use_cache=True) + self._check_generate_outputs(output, model.config, use_cache=True) @pytest.mark.generate def test_dola_decoding_sample(self): @@ -1320,7 +1345,7 @@ class GenerationTesterMixin: "dola_layers": "low", } output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict) - self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False)) + self._check_generate_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False)) @pytest.mark.generate def test_assisted_decoding_sample(self): @@ -1381,7 +1406,7 @@ class GenerationTesterMixin: } output_assisted = model.generate(**generation_kwargs, **inputs_dict) - self._check_outputs(output_assisted, config, use_cache=True) + self._check_generate_outputs(output_assisted, config, use_cache=True) @pytest.mark.generate def test_prompt_lookup_decoding_stops_at_eos(self): @@ -1419,6 +1444,8 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() text_config = config.get_text_config() + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise # We want to test only encoder-decoder models if not text_config.is_encoder_decoder: @@ -1765,7 +1792,7 @@ class GenerationTesterMixin: num_hidden_layers = text_config.num_hidden_layers inputs_embeds = model.get_input_embeddings()(input_ids) - max_cache_len += inputs_embeds.shape[1] + max_cache_len += inputs_embeds.shape[1] - 1 # the last generated token has no cache outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict) # we should get `max_length` in shape, not `max_length - embeds_length` @@ -2003,7 +2030,7 @@ class GenerationTesterMixin: ) # Check 1: The cache shapes must match the expected shapes - max_cache_len = seq_length + max_new_tokens + max_cache_len = seq_length + max_new_tokens - 1 # cache len = gen len - 1, the last token has no cache text_config = config.text_config if hasattr(config, "text_config") else config head_dim = ( text_config.head_dim @@ -2138,6 +2165,58 @@ class GenerationTesterMixin: for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs): self._check_similar_generate_outputs(dynamic_result, compiled_result) + @pytest.mark.generate + def test_generate_compilation_all_outputs(self): + """ + Tests that all optional outputs are behaving as expected when compilation is triggered. + In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered. + """ + for model_class in self.all_generative_model_classes: + if not model_class._supports_static_cache: + self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") + + config, inputs_dict = self.prepare_config_and_inputs_for_generate() + if self.has_attentions: + config._attn_implementation = "eager" # can't output attentions otherwise + model = model_class(config).to(torch_device).eval() + + # compilation-specific setup + torch.compiler.reset() # prevent cached compilation from being used in the test + has_defined_cache_implementation = model.generation_config.cache_implementation is not None + model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU) + if not has_defined_cache_implementation: + model.generation_config.cache_implementation = "static" + + logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) + output_generate = model.generate( + do_sample=False, + num_beams=1, + max_new_tokens=self.max_new_tokens, + min_new_tokens=self.max_new_tokens, + output_attentions=True, + output_hidden_states=True, + output_scores=True, + output_logits=True, + return_dict_in_generate=True, + use_cache=True, + **logits_processor_kwargs, + **inputs_dict, + ) + + # Sanity check: compilation has happened + self.assertTrue(hasattr(model, "_compiled_call")) + + if model.config.is_encoder_decoder: + self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + else: + self.assertTrue( + output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + ) + self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) + + self._check_generate_outputs(output_generate, model.config, use_cache=True) + @pytest.mark.generate def test_generate_methods_with_logits_to_keep(self): for model_class in self.all_generative_model_classes: @@ -2290,86 +2369,90 @@ class GenerationTesterMixin: # check whether we still need the overwrites self._test_attention_implementation("flash_attention_2") - def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): + def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): input_batch_size = int(output.sequences.shape[0] / num_return_sequences) internal_batch_size = ( input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences ) - seq_length = getattr(self.model_tester, "seq_length", None) - seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length) - seq_length = getattr(self.model_tester, "text_seq_length", seq_length) + prompt_length = getattr(self.model_tester, "seq_length", None) + prompt_length = getattr(self.model_tester, "encoder_seq_length", prompt_length) + prompt_length = getattr(self.model_tester, "text_seq_length", prompt_length) config = config.text_config if hasattr(config, "text_config") else config - gen_len = ( - output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length + generated_length = ( + output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length ) + decoder_past_key_values = getattr(output, "past_key_values", None) + if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache): + decoder_past_key_values = decoder_past_key_values.self_attention_cache # in some models we subsample the sequence length in inner layers if hasattr(self.model_tester, "get_subsampled_output_lengths"): - seq_length = self.model_tester.get_subsampled_output_lengths(seq_length) + prompt_length = self.model_tester.get_subsampled_output_lengths(prompt_length) # scores - self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config) + self._check_scores( + batch_size=internal_batch_size, scores=output.scores, generated_length=generated_length, config=config + ) # unprocessed logits - self._check_logits(internal_batch_size, output.logits, config=config) + self._check_logits(batch_size=internal_batch_size, logits=output.logits, config=config) # Attentions if self.has_attentions: if config.is_encoder_decoder: # encoder self._check_encoder_attention_for_generate( - output.encoder_attentions, input_batch_size, config, seq_length + attentions=output.encoder_attentions, + batch_size=input_batch_size, + config=config, + prompt_length=prompt_length, ) # decoder self._check_attentions_for_generate( - internal_batch_size, - output.decoder_attentions, - min_length=1, - max_length=output.sequences.shape[-1], + batch_size=internal_batch_size, + attentions=output.decoder_attentions, + prompt_length=1, # the BOS token + output_length=output.sequences.shape[-1], config=config, - use_cache=use_cache, + decoder_past_key_values=decoder_past_key_values, ) else: - # if use_cache first input is equal to no use_cache, so skip here - attentions = output.attentions if not use_cache else output.attentions[1:] - min_length = seq_length if not use_cache else seq_length + 1 self._check_attentions_for_generate( - internal_batch_size, - attentions=attentions, - min_length=min_length, - max_length=output.sequences.shape[-1], + batch_size=internal_batch_size, + attentions=output.attentions, + prompt_length=prompt_length, + output_length=output.sequences.shape[-1], config=config, - use_cache=use_cache, + decoder_past_key_values=decoder_past_key_values, ) # Hidden States if config.is_encoder_decoder: # encoder self._check_encoder_hidden_states_for_generate( - output.encoder_hidden_states, input_batch_size, config, seq_length + hidden_states=output.encoder_hidden_states, + batch_size=input_batch_size, + config=config, + prompt_length=prompt_length, ) - # decoder self._check_hidden_states_for_generate( - internal_batch_size, - output.decoder_hidden_states, - min_length=1, - max_length=output.sequences.shape[-1], + batch_size=internal_batch_size, + hidden_states=output.decoder_hidden_states, + prompt_length=1, # the BOS token + output_length=output.sequences.shape[-1], config=config, use_cache=use_cache, ) else: - # if use_cache first input is equal to no use_cache, so skip here - hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] - min_length = seq_length if not use_cache else seq_length + 1 self._check_hidden_states_for_generate( - internal_batch_size, - hidden_states, - min_length=min_length, - max_length=output.sequences.shape[-1], + batch_size=internal_batch_size, + hidden_states=output.hidden_states, + prompt_length=prompt_length, + output_length=output.sequences.shape[-1], config=config, use_cache=use_cache, ) @@ -2396,59 +2479,73 @@ class GenerationTesterMixin: ) if has_standard_cache: if use_cache: - past_key_values = output.past_key_values - past_sequence_length = output.sequences.shape[-1] - 1 + cache_length = output.sequences.shape[-1] - 1 self._check_past_key_values_for_generate( - internal_batch_size, - past_key_values, - seq_length=past_sequence_length, + batch_size=internal_batch_size, + decoder_past_key_values=decoder_past_key_values, + cache_length=cache_length, config=config, ) elif use_cache is False: - self.assertTrue(output.past_key_values is None) + self.assertTrue(decoder_past_key_values is None) - def _check_scores(self, batch_size, scores, length, config): + def _check_scores(self, batch_size, scores, generated_length, config): vocab_size = config.get_text_config(decoder=True).vocab_size expected_shape = (batch_size, vocab_size) self.assertIsInstance(scores, tuple) - self.assertEqual(len(scores), length) + self.assertEqual(len(scores), generated_length) self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) - def _check_logits(self, batch_size, scores, config): + def _check_logits(self, batch_size, logits, config): vocab_size = config.get_text_config(decoder=True).vocab_size - self.assertIsInstance(scores, tuple) - self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores)) + self.assertIsInstance(logits, tuple) + self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits)) # vocabulary difference equal to one (imagegptmodel?) or zero (all other models) - vocab_diff = vocab_size - scores[0].shape[-1] + vocab_diff = vocab_size - logits[0].shape[-1] self.assertTrue(vocab_diff in [0, 1]) - self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores)) + self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits)) def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ): self.assertIsInstance(attentions, tuple) self.assertListEqual( [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + self.assertEqual(len(attentions), (output_length - prompt_length)) - for idx, iter_attentions in enumerate(attentions): - tgt_len = min_length + idx if not use_cache else 1 - src_len = min_length + idx + use_cache = decoder_past_key_values is not None + has_static_cache = isinstance(decoder_past_key_values, (StaticCache, HybridCache)) + + # When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new + # token(s) + # NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more + # elaborate checks + for generated_length, iter_attentions in enumerate(attentions): + # regardless of using cache, the first forward pass will have the full prompt as input + if use_cache and generated_length > 0: + model_input_length = 1 + else: + model_input_length = prompt_length + generated_length + query_length = ( + prompt_length + generated_length + if not has_static_cache + else decoder_past_key_values.get_max_cache_shape() + ) expected_shape = ( - batch_size * num_beam_groups, + batch_size, config.num_attention_heads, - tgt_len, - src_len, + model_input_length, + query_length, ) # check attn size self.assertListEqual( [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) ) - def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): - encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length) + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length): + encoder_expected_shape = (batch_size, config.num_attention_heads, prompt_length, prompt_length) self.assertIsInstance(attentions, tuple) self.assertListEqual( [layer_attentions.shape for layer_attentions in attentions], @@ -2456,71 +2553,75 @@ class GenerationTesterMixin: ) def _check_hidden_states_for_generate( - self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False ): self.assertIsInstance(hidden_states, tuple) self.assertListEqual( [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], [True] * len(hidden_states), ) - self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + self.assertEqual(len(hidden_states), (output_length - prompt_length)) - for idx, iter_hidden_states in enumerate(hidden_states): - seq_len = min_length + idx if not use_cache else 1 - expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + # When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the + # new token(s) + # NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more + # elaborate checks + for generated_length, iter_hidden_states in enumerate(hidden_states): + # regardless of using cache, the first forward pass will have the full prompt as input + if use_cache and generated_length > 0: + model_input_length = 1 + else: + model_input_length = prompt_length + generated_length + expected_shape = (batch_size, model_input_length, config.hidden_size) # check hidden size self.assertListEqual( [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], [expected_shape] * len(iter_hidden_states), ) - def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length): - encoder_expected_shape = (batch_size, seq_length, config.hidden_size) + def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length): + encoder_expected_shape = (batch_size, prompt_length, config.hidden_size) self.assertIsInstance(hidden_states, tuple) self.assertListEqual( [layer_hidden_states.shape for layer_hidden_states in hidden_states], [encoder_expected_shape] * len(hidden_states), ) - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1): - self.assertIsInstance(past_key_values, (tuple, Cache)) - - # Encoder-decoder models: pull and verify the decoder cache - if isinstance(past_key_values, EncoderDecoderCache): - past_key_values = past_key_values.self_attention_cache + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, (tuple, Cache)) # (batch, head, seq_length, head_features) expected_shape = ( - batch_size * num_beam_groups, + batch_size, config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, - seq_length, + cache_length, config.hidden_size // config.num_attention_heads, ) - if isinstance(past_key_values, Cache): + if isinstance(decoder_past_key_values, Cache): self.assertListEqual( - [key_tensor.shape for key_tensor in past_key_values.key_cache], - [expected_shape] * len(past_key_values.key_cache), + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), ) self.assertListEqual( - [value_tensor.shape for value_tensor in past_key_values.value_cache], - [expected_shape] * len(past_key_values.value_cache), + [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), ) # Legacy cache format checks. This branch should be removed when all models use `Cache` by default else: self.assertListEqual( - [isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values], - [True] * len(past_key_values), + [isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values], + [True] * len(decoder_past_key_values), ) # check shape key, value self.assertListEqual( - [layer_past_key_values[0].shape for layer_past_key_values in past_key_values], - [expected_shape] * len(past_key_values), + [layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values], + [expected_shape] * len(decoder_past_key_values), ) self.assertListEqual( - [layer_past_key_values[1].shape for layer_past_key_values in past_key_values], - [expected_shape] * len(past_key_values), + [layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values], + [expected_shape] * len(decoder_past_key_values), ) def _check_sequence_inside_sequence(self, tensor_1, tensor_2): diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index 628eaba738..775971cf28 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -723,103 +723,12 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT self.assertIsNotNone(model) # overwrite because BLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format - def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): + def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): use_cache = True # force this to be True in case False is passed - - input_batch_size = int(output.sequences.shape[0] / num_return_sequences) - internal_batch_size = ( - input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences + super()._check_generate_outputs( + output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams ) - seq_length = getattr(self.model_tester, "seq_length", None) - seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length) - seq_length = getattr(self.model_tester, "text_seq_length", seq_length) - - config = config.text_config if hasattr(config, "text_config") else config - - gen_len = ( - output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length - ) - - # in some models we subsample the sequence length in inner layers - if hasattr(self.model_tester, "get_subsampled_output_lengths"): - seq_length = self.model_tester.get_subsampled_output_lengths(seq_length) - - # scores - self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config) - - # unprocessed logits - self._check_logits(internal_batch_size, output.logits, config=config) - - # Attentions - if self.has_attentions: - if config.is_encoder_decoder: - # encoder - self._check_encoder_attention_for_generate( - output.encoder_attentions, input_batch_size, config, seq_length - ) - # decoder - self._check_attentions_for_generate( - internal_batch_size, - output.decoder_attentions, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - else: - # if use_cache first input is equal to no use_cache, so skip here - attentions = output.attentions if not use_cache else output.attentions[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_attentions_for_generate( - internal_batch_size, - attentions=attentions, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Hidden States - if config.is_encoder_decoder: - # encoder - self._check_encoder_hidden_states_for_generate( - output.encoder_hidden_states, input_batch_size, config, seq_length - ) - - # decoder - self._check_hidden_states_for_generate( - internal_batch_size, - output.decoder_hidden_states, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - else: - # if use_cache first input is equal to no use_cache, so skip here - hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_hidden_states_for_generate( - internal_batch_size, - hidden_states, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Past Key Value States - if use_cache: - past_key_values = output.past_key_values - past_sequence_length = output.sequences.shape[-1] - 1 - self._check_past_key_values_for_generate( - internal_batch_size, - past_key_values, - seq_length=past_sequence_length, - config=config, - ) - # overwrite because BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present @pytest.mark.generate def test_left_padding_compatibility(self): diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index 81ea53b49f..881856ea70 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -20,7 +20,7 @@ from packaging import version from parameterized import parameterized from pytest import mark -from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, HybridCache, is_torch_available, pipeline +from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_torch_available, pipeline from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( require_flash_attn, @@ -135,51 +135,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase): def test_generate_continue_from_inputs_embeds(self): pass - # overwrite because HybridCache has fixed length for key/values - def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 - ): - self.assertIsInstance(attentions, tuple) - self.assertListEqual( - [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) - ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) - - for idx, iter_attentions in enumerate(attentions): - tgt_len = min_length + idx if not use_cache else 1 - src_len = min_length + idx if not use_cache else max_length - - expected_shape = ( - batch_size * num_beam_groups, - config.num_attention_heads, - tgt_len, - src_len, - ) - # check attn size - self.assertListEqual( - [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) - ) - - # overwrite because HybridCache has fixed length for key/values - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1): - self.assertIsInstance(past_key_values, HybridCache) - - # check shape key, value (batch, head, max_seq_length, head_features) - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - num_hidden_layers = config.num_hidden_layers - - # we should get `max_length` in shape, not `max_length - embeds_length` - # `+1` because the test in Mixin subtracts 1 which is needed for tuple cache - static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim) - static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean] - self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape) - @unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different") def test_sdpa_equivalence(self): pass diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index a0563aed90..8d02565b4b 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -20,7 +20,7 @@ from packaging import version from parameterized import parameterized from pytest import mark -from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline +from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( require_flash_attn, @@ -150,51 +150,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): def test_generate_continue_from_inputs_embeds(self): pass - # overwrite because HybridCache has fixed length for key/values - def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 - ): - self.assertIsInstance(attentions, tuple) - self.assertListEqual( - [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) - ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) - - for idx, iter_attentions in enumerate(attentions): - tgt_len = min_length + idx if not use_cache else 1 - src_len = min_length + idx if not use_cache else max_length - - expected_shape = ( - batch_size * num_beam_groups, - config.num_attention_heads, - tgt_len, - src_len, - ) - # check attn size - self.assertListEqual( - [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) - ) - - # overwrite because HybridCache has fixed length for key/values - def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1): - self.assertIsInstance(past_key_values, HybridCache) - - # check shape key, value (batch, head, max_seq_length, head_features) - head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads - num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - num_hidden_layers = config.num_hidden_layers - - # we should get `max_length` in shape, not `max_length - embeds_length` - # `+1` because the test in Mixin subtracts 1 which is needed for tuple cache - static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim) - static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean] - self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers) - self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape) - @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different") def test_sdpa_equivalence(self): pass diff --git a/tests/models/git/test_modeling_git.py b/tests/models/git/test_modeling_git.py index ff9c086bdb..e4251c700e 100644 --- a/tests/models/git/test_modeling_git.py +++ b/tests/models/git/test_modeling_git.py @@ -456,51 +456,26 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, self.model_tester.create_and_check_model(*config_and_inputs) def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ): # GIT attention shape depends on image inputs, overwrite - self.assertIsInstance(attentions, tuple) - self.assertListEqual( - [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) - ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1) - - for idx, iter_attentions in enumerate(attentions): - tgt_len = min_length + idx + image_length if not use_cache else 1 - src_len = min_length + idx + image_length - - expected_shape = ( - batch_size * num_beam_groups, - config.num_attention_heads, - tgt_len, - src_len, - ) - # check attn size - self.assertListEqual( - [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) - ) + prompt_length += image_length + output_length += image_length + super()._check_attentions_for_generate( + batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values + ) def _check_hidden_states_for_generate( - self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False ): # GIT attention shape depends on image inputs, overwrite - self.assertIsInstance(hidden_states, tuple) - self.assertListEqual( - [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], - [True] * len(hidden_states), - ) - self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1) - - for idx, iter_hidden_states in enumerate(hidden_states): - seq_len = min_length + idx + image_length if not use_cache else 1 - expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) - # check hidden size - self.assertListEqual( - [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], - [expected_shape] * len(iter_hidden_states), - ) + prompt_length += image_length + output_length += image_length + super()._check_hidden_states_for_generate( + batch_size, hidden_states, prompt_length, output_length, config, use_cache=use_cache + ) @slow def test_model_from_pretrained(self): diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index cc9efc967d..1306dc50d9 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ b/tests/models/idefics/test_modeling_idefics.py @@ -815,7 +815,7 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni ) def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ): """ Overwrite from generation tests because Idefics has only SDPA layers. diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py index a8bcb8d180..6d96c444c3 100644 --- a/tests/models/imagegpt/test_modeling_imagegpt.py +++ b/tests/models/imagegpt/test_modeling_imagegpt.py @@ -251,10 +251,10 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM return inputs_dict # we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalImageModeling doesn't have tied input- and output embeddings - def _check_scores(self, batch_size, scores, length, config): + def _check_scores(self, batch_size, scores, generated_length, config): expected_shape = (batch_size, config.vocab_size - 1) self.assertIsInstance(scores, tuple) - self.assertEqual(len(scores), length) + self.assertEqual(len(scores), generated_length) self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) @run_test_using_subprocess diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index d472274fab..434784f05a 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -565,103 +565,12 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene self.assertIsNotNone(model) # overwrite because InstructBLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format - def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): + def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): use_cache = True # force this to be True in case False is passed - - input_batch_size = int(output.sequences.shape[0] / num_return_sequences) - internal_batch_size = ( - input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences + super()._check_generate_outputs( + output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams ) - seq_length = getattr(self.model_tester, "seq_length", None) - seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length) - seq_length = getattr(self.model_tester, "text_seq_length", seq_length) - - config = config.text_config if hasattr(config, "text_config") else config - - gen_len = ( - output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length - ) - - # in some models we subsample the sequence length in inner layers - if hasattr(self.model_tester, "get_subsampled_output_lengths"): - seq_length = self.model_tester.get_subsampled_output_lengths(seq_length) - - # scores - self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config) - - # unprocessed logits - self._check_logits(internal_batch_size, output.logits, config=config) - - # Attentions - if self.has_attentions: - if config.is_encoder_decoder: - # encoder - self._check_encoder_attention_for_generate( - output.encoder_attentions, input_batch_size, config, seq_length - ) - # decoder - self._check_attentions_for_generate( - internal_batch_size, - output.decoder_attentions, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - else: - # if use_cache first input is equal to no use_cache, so skip here - attentions = output.attentions if not use_cache else output.attentions[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_attentions_for_generate( - internal_batch_size, - attentions=attentions, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Hidden States - if config.is_encoder_decoder: - # encoder - self._check_encoder_hidden_states_for_generate( - output.encoder_hidden_states, input_batch_size, config, seq_length - ) - - # decoder - self._check_hidden_states_for_generate( - internal_batch_size, - output.decoder_hidden_states, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - else: - # if use_cache first input is equal to no use_cache, so skip here - hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_hidden_states_for_generate( - internal_batch_size, - hidden_states, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Past Key Value States - if use_cache: - past_key_values = output.past_key_values - past_sequence_length = output.sequences.shape[-1] - 1 - self._check_past_key_values_for_generate( - internal_batch_size, - past_key_values, - seq_length=past_sequence_length, - config=config, - ) - # overwrite because InstructBLIP cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present @pytest.mark.generate def test_left_padding_compatibility(self): diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index 76c5c11de2..e8ed52b723 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -581,103 +581,12 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest( self.assertIsNotNone(model) # overwrite because InstructBLIPVideo internally calls LM.generate() with embeds thus it cannot operate in no cache format - def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): + def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): use_cache = True # force this to be True in case False is passed - - input_batch_size = int(output.sequences.shape[0] / num_return_sequences) - internal_batch_size = ( - input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences + super()._check_generate_outputs( + output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams ) - seq_length = getattr(self.model_tester, "seq_length", None) - seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length) - seq_length = getattr(self.model_tester, "text_seq_length", seq_length) - - config = config.text_config if hasattr(config, "text_config") else config - - gen_len = ( - output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length - ) - - # in some models we subsample the sequence length in inner layers - if hasattr(self.model_tester, "get_subsampled_output_lengths"): - seq_length = self.model_tester.get_subsampled_output_lengths(seq_length) - - # scores - self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config) - - # unprocessed logits - self._check_logits(internal_batch_size, output.logits, config=config) - - # Attentions - if self.has_attentions: - if config.is_encoder_decoder: - # encoder - self._check_encoder_attention_for_generate( - output.encoder_attentions, input_batch_size, config, seq_length - ) - # decoder - self._check_attentions_for_generate( - internal_batch_size, - output.decoder_attentions, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - else: - # if use_cache first input is equal to no use_cache, so skip here - attentions = output.attentions if not use_cache else output.attentions[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_attentions_for_generate( - internal_batch_size, - attentions=attentions, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Hidden States - if config.is_encoder_decoder: - # encoder - self._check_encoder_hidden_states_for_generate( - output.encoder_hidden_states, input_batch_size, config, seq_length - ) - - # decoder - self._check_hidden_states_for_generate( - internal_batch_size, - output.decoder_hidden_states, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - else: - # if use_cache first input is equal to no use_cache, so skip here - hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:] - min_length = seq_length if not use_cache else seq_length + 1 - self._check_hidden_states_for_generate( - internal_batch_size, - hidden_states, - min_length=min_length, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Past Key Value States - if use_cache: - past_key_values = output.past_key_values - past_sequence_length = output.sequences.shape[-1] - 1 - self._check_past_key_values_for_generate( - internal_batch_size, - past_key_values, - seq_length=past_sequence_length, - config=config, - ) - # overwrite because InstructBLIPVideo cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present @pytest.mark.generate def test_left_padding_compatibility(self): diff --git a/tests/models/led/test_modeling_led.py b/tests/models/led/test_modeling_led.py index 3d21fa0a69..b6a3db5c94 100644 --- a/tests/models/led/test_modeling_led.py +++ b/tests/models/led/test_modeling_led.py @@ -468,12 +468,12 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ], ) - def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length): # overwrite because LED does not have (bs, num_heads, seq_len, seq_len) shape encoder_expected_shape = ( batch_size, config.num_attention_heads, - seq_length, + prompt_length, self.model_tester.attention_window // 2 * 2 + 1, ) self.assertIsInstance(attentions, tuple) diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py index c2c2563b55..a166a6dab7 100644 --- a/tests/models/longt5/test_modeling_longt5.py +++ b/tests/models/longt5/test_modeling_longt5.py @@ -785,7 +785,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix [self.model_tester.num_attention_heads, block_len, 3 * block_len], ) - def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length): block_len = getattr(self.model_tester, "block_len", None) encoder_expected_shape = (batch_size, 2, config.num_attention_heads, block_len, 3 * block_len) self.assertIsInstance(attentions, tuple) @@ -920,10 +920,10 @@ class LongT5TGlobalModelTest(LongT5ModelTest): [self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len], ) - def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length): block_len = getattr(self.model_tester, "block_len", None) global_block_size = getattr(self.model_tester, "global_block_size", None) - global_seq_length = seq_length // global_block_size + global_seq_length = prompt_length // global_block_size encoder_expected_shape = ( batch_size, 2, diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index e1ded5e934..4e4c4636b7 100644 --- a/tests/models/mllama/test_modeling_mllama.py +++ b/tests/models/mllama/test_modeling_mllama.py @@ -323,32 +323,37 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester torch.testing.assert_close(out_embeds, out_ids) def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ): # Mllama has cross attention layers and those have a different shape than normal attention layers self.assertIsInstance(attentions, tuple) self.assertListEqual( [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + self.assertEqual(len(attentions), (output_length - prompt_length)) cross_attention_layers = self.model_tester.text_config["cross_attention_layers"] + use_cache = decoder_past_key_values is not None - for idx, iter_attentions in enumerate(attentions): - tgt_len = min_length + idx if not use_cache else 1 - src_len = min_length + idx + for generated_length, iter_attentions in enumerate(attentions): + # regardless of using cache, the first forward pass will have the full prompt as input + if use_cache and generated_length > 0: + model_input_length = 1 + else: + model_input_length = prompt_length + generated_length + query_length = prompt_length + generated_length expected_shape = ( - batch_size * num_beam_groups, + batch_size, config.num_attention_heads, - tgt_len, - src_len, + model_input_length, + query_length, ) expected_shape_cross = ( - batch_size * num_beam_groups, + batch_size, config.num_attention_heads, - tgt_len, + model_input_length, self.model_tester.image_length, ) diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index 09278f0d24..9eb0eaa4d4 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -575,77 +575,12 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): return config, filtered_inputs_dict - def _check_hidden_states_for_generate( - self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 - ): + def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): # Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True` - self.assertIsInstance(hidden_states, tuple) - self.assertListEqual( - [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], - [True] * len(hidden_states), - ) - self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) - - for idx, iter_hidden_states in enumerate(hidden_states): - seq_len = min_length if idx == 0 else 1 - expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) - # check hidden size - self.assertListEqual( - [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], - [expected_shape] * len(iter_hidden_states), - ) - - def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1): - # Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True` - super()._check_outputs( + super()._check_generate_outputs( output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams ) - def _check_hidden_states_for_generate( - self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 - ): - # Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True` - self.assertIsInstance(hidden_states, tuple) - self.assertListEqual( - [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], - [True] * len(hidden_states), - ) - self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) - - for idx, iter_hidden_states in enumerate(hidden_states): - seq_len = 1 - expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) - # check hidden size - self.assertListEqual( - [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], - [expected_shape] * len(iter_hidden_states), - ) - - def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 - ): - # Overwrite because the generate method actually alway uses `inputs_embeds` so `use_cache` is always `True` - self.assertIsInstance(attentions, tuple) - self.assertListEqual( - [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) - ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) - - for idx, iter_attentions in enumerate(attentions): - tgt_len = 1 - src_len = min_length + idx - - expected_shape = ( - batch_size * num_beam_groups, - config.num_attention_heads, - tgt_len, - src_len, - ) - # check attn size - self.assertListEqual( - [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) - ) - def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index 2463b7ab26..2c3b74edd1 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -399,11 +399,11 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ], ) - def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length): encoder_expected_shape = ( batch_size, config.num_attention_heads, - math.ceil(seq_length / config.block_size), + math.ceil(prompt_length / config.block_size), config.block_size, config.block_size + config.num_global_tokens, ) @@ -413,8 +413,8 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM [encoder_expected_shape] * len(attentions), ) - def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length): - encoder_expected_shape = (batch_size, self.round_up(seq_length, config.block_size), config.hidden_size) + def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length): + encoder_expected_shape = (batch_size, self.round_up(prompt_length, config.block_size), config.hidden_size) self.assertIsInstance(hidden_states, tuple) # Only the last layer will have the hidden states truncated back to token level self.assertListEqual( @@ -424,7 +424,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM # Only the last layer will have the hidden states truncated back to token level self.assertEqual( hidden_states[-1][0].shape, - (batch_size, seq_length, config.hidden_size), + (batch_size, prompt_length, config.hidden_size), ) def test_hidden_states_output(self): diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index adec2c893a..3b051db37c 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -753,20 +753,20 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name) self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict()) - def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length): + def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length): # overwrite because # pix2struct seq length depends on image inputs - seq_length = self.model_tester.max_patches - encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length) + prompt_length = self.model_tester.max_patches + encoder_expected_shape = (batch_size, config.num_attention_heads, prompt_length, prompt_length) self.assertIsInstance(attentions, tuple) self.assertListEqual( [layer_attentions.shape for layer_attentions in attentions], [encoder_expected_shape] * len(attentions), ) - def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length): + def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length): # overwrite because # pix2struct seq length depends on image inputs - seq_length = self.model_tester.max_patches - encoder_expected_shape = (batch_size, seq_length, config.hidden_size) + prompt_length = self.model_tester.max_patches + encoder_expected_shape = (batch_size, prompt_length, config.hidden_size) self.assertIsInstance(hidden_states, tuple) self.assertListEqual( [layer_hidden_states.shape for layer_hidden_states in hidden_states], diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py index 985115d770..a7cfc2a04f 100644 --- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -367,9 +367,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT def test_training_gradient_checkpointing_use_reentrant_false(self): pass - def _check_attentions_for_generate(self, *args, **kwargs): - return True # Model does not return attention - @unittest.skip(reason="Past key values are not returned") def test_prompt_lookup_decoding_matches_greedy_search(self): pass @@ -382,9 +379,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT def test_model_parallel_beam_search(self): pass - def _check_past_key_values_for_generate(self, *args, **kwargs): - return True - @unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported") def test_assisted_decoding_matches_greedy_search(self): pass @@ -397,25 +391,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT def test_assisted_decoding_sample(self): pass - def _check_hidden_states_for_generate( - self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 - ): - self.assertIsInstance(hidden_states, tuple) - self.assertListEqual( - [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], - [True] * len(hidden_states), - ) - self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) - - for idx, iter_hidden_states in enumerate(hidden_states): - seq_len = min_length + idx if not use_cache else 1 - expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) - # check hidden size - self.assertListEqual( - [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], - [expected_shape] * len(iter_hidden_states), - ) - @unittest.skip(reason="TODO @arthurzucker not super important and failing.") def test_initialization(self): pass diff --git a/tests/models/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py index fde19b7454..24b59b2f1b 100644 --- a/tests/models/reformer/test_modeling_reformer.py +++ b/tests/models/reformer/test_modeling_reformer.py @@ -620,36 +620,42 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod self.assertIsNotNone(model) def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ): + # NOTE (joao): this function is substancially different from the original, the attention has different + # *number* of shapes in certain conditions self.assertIsInstance(attentions, tuple) self.assertListEqual( [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions) ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + self.assertEqual(len(attentions), (output_length - prompt_length)) - for idx, iter_attentions in enumerate(attentions): - tgt_len = min_length + idx if not use_cache else 1 - num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % config.local_attn_chunk_length != 0) - tgt_chunk_len = config.local_attn_chunk_length - src_chunk_len = config.local_attn_chunk_length * ( + for generated_length, iter_attentions in enumerate(attentions): + use_cache = decoder_past_key_values is not None and generated_length > 0 + + model_input_length = prompt_length + generated_length if not use_cache else 1 + num_chunks = model_input_length // config.local_attn_chunk_length + ( + model_input_length % config.local_attn_chunk_length != 0 + ) + model_input_chunk_len = config.local_attn_chunk_length + query_chunk_len = config.local_attn_chunk_length * ( 1 + config.local_num_chunks_after + config.local_num_chunks_before ) if use_cache: expected_shape = ( - batch_size * num_beam_groups, + batch_size, config.num_attention_heads, - tgt_len, - min_length // config.local_attn_chunk_length + 1 + idx, + model_input_length, + prompt_length // config.local_attn_chunk_length + generated_length, ) else: expected_shape = ( - batch_size * num_beam_groups, + batch_size, config.num_attention_heads, num_chunks, - tgt_chunk_len, - src_chunk_len, + model_input_chunk_len, + query_chunk_len, ) # check attn size self.assertListEqual( @@ -657,25 +663,29 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod ) def _check_hidden_states_for_generate( - self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False ): + # NOTE (joao): this function is substancially different from the original, the hidden states have different + # length in certain conditions self.assertIsInstance(hidden_states, tuple) self.assertListEqual( [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states], [True] * len(hidden_states), ) - self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + self.assertEqual(len(hidden_states), (output_length - prompt_length)) - for idx, iter_hidden_states in enumerate(hidden_states): - seq_len = min_length + idx - seq_len = config.local_attn_chunk_length * ( - seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0) + for generation_length, iter_hidden_states in enumerate(hidden_states): + use_cache_this_iter = use_cache and generation_length > 0 + model_input_length = prompt_length + generation_length + model_output_length = config.local_attn_chunk_length * ( + model_input_length // config.local_attn_chunk_length + + (model_input_length % config.local_attn_chunk_length != 0) ) - if use_cache: - seq_len = 1 + if use_cache_this_iter: + model_output_length = 1 - expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + expected_shape = (batch_size, model_output_length, config.hidden_size) # check hidden size self.assertListEqual( [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], @@ -789,37 +799,42 @@ class ReformerLSHAttnModelTest( self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37) def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ): + # NOTE (joao): this function is substancially different from the original, the attention has different + # *number* of shapes in certain conditions self.assertIsInstance(attentions, tuple) self.assertListEqual( [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions) ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + self.assertEqual(len(attentions), (output_length - prompt_length)) - for idx, iter_attentions in enumerate(attentions): - tgt_len = min_length + idx if not use_cache else 1 - num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0) - tgt_chunk_len = config.lsh_attn_chunk_length - src_chunk_len = config.lsh_attn_chunk_length * ( + for generated_length, iter_attentions in enumerate(attentions): + use_cache = decoder_past_key_values is not None and generated_length > 0 + model_input_len = prompt_length + generated_length if not use_cache else 1 + num_chunks = model_input_len // config.lsh_attn_chunk_length + ( + model_input_len % config.lsh_attn_chunk_length != 0 + ) + model_input_chunk_len = config.lsh_attn_chunk_length + query_chunk_len = config.lsh_attn_chunk_length * ( 1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before ) if use_cache: expected_shape = ( - batch_size * num_beam_groups, + batch_size, config.num_attention_heads, config.num_hashes, - tgt_len, + model_input_len, config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before), ) else: expected_shape = ( - batch_size * num_beam_groups, + batch_size, config.num_attention_heads, num_chunks * config.num_hashes, - tgt_chunk_len, - src_chunk_len, + model_input_chunk_len, + query_chunk_len, ) # check attn size self.assertListEqual( @@ -827,25 +842,29 @@ class ReformerLSHAttnModelTest( ) def _check_hidden_states_for_generate( - self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False ): + # NOTE (joao): this function is substancially different from the original, the hidden states have different + # length in certain conditions self.assertIsInstance(hidden_states, tuple) self.assertListEqual( [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states], [True] * len(hidden_states), ) - self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + self.assertEqual(len(hidden_states), (output_length - prompt_length)) - for idx, iter_hidden_states in enumerate(hidden_states): - seq_len = min_length + idx if not use_cache else 1 - seq_len = config.lsh_attn_chunk_length * ( - seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0) + for generation_length, iter_hidden_states in enumerate(hidden_states): + use_cache_this_iter = use_cache and generation_length > 0 + model_input_length = prompt_length + generation_length + model_output_length = config.local_attn_chunk_length * ( + model_input_length // config.local_attn_chunk_length + + (model_input_length % config.local_attn_chunk_length != 0) ) - if use_cache: - seq_len = 1 + if use_cache_this_iter: + model_output_length = 1 - expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + expected_shape = (batch_size, model_output_length, config.hidden_size) # check hidden size self.assertListEqual( [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], diff --git a/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py b/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py index 26fee7d93c..ef9a2b33bc 100644 --- a/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_tf_speech_to_text.py @@ -416,48 +416,6 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.T def test_generate_without_input_ids(self): pass - def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): - batch_size, seq_length = input_ids.shape[:2] - subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length) - num_sequences_in_output = batch_size * num_return_sequences - gen_len = ( - output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length - ) - - # scores - self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) - - # Attentions - # encoder - self._check_encoder_attention_for_generate( - output.encoder_attentions, batch_size, config, subsampled_seq_length - ) - # decoder - self._check_attentions_for_generate( - num_sequences_in_output, - output.decoder_attentions, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Hidden States - # encoder - self._check_encoder_hidden_states_for_generate( - output.encoder_hidden_states, batch_size, config, subsampled_seq_length - ) - - # decoder - self._check_hidden_states_for_generate( - num_sequences_in_output, - output.decoder_hidden_states, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - # overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is # `input_features` def test_lm_head_model_random_no_beam_search_generate(self): diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 504b6174fc..7aacf51719 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -527,48 +527,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC def test_generate_without_input_ids(self): pass - def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): - batch_size, mel, seq_length = input_ids.shape - subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length) - num_sequences_in_output = batch_size * num_return_sequences - gen_len = ( - output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length - ) - - # scores - self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) - - # Attentions - # encoder - self._check_encoder_attention_for_generate( - output.encoder_attentions, batch_size, config, subsampled_seq_length - ) - # decoder - self._check_attentions_for_generate( - num_sequences_in_output, - output.decoder_attentions, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - - # Hidden States - # encoder - self._check_encoder_hidden_states_for_generate( - output.encoder_hidden_states, batch_size, config, subsampled_seq_length - ) - - # decoder - self._check_hidden_states_for_generate( - num_sequences_in_output, - output.decoder_hidden_states, - min_length=1, - max_length=output.sequences.shape[-1], - config=config, - use_cache=use_cache, - ) - # overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is # `input_features` def test_lm_head_model_random_no_beam_search_generate(self): diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index fe41afabf4..916517add7 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1607,6 +1607,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_generate_compile_model_forward(self): pass + # TODO (joao, eustache): fix me :) + @unittest.skip(reason="A CUDA exception is thrown when storing extra outputs") + def test_generate_compilation_all_outputs(self): + pass + @require_torch @require_torchaudio diff --git a/tests/models/xlm/test_modeling_xlm.py b/tests/models/xlm/test_modeling_xlm.py index 556f97c0b2..d2eefced08 100644 --- a/tests/models/xlm/test_modeling_xlm.py +++ b/tests/models/xlm/test_modeling_xlm.py @@ -473,50 +473,24 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs) def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ): - self.assertIsInstance(attentions, tuple) - self.assertListEqual( - [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) + # adds PAD dummy token, expected shape is off by 1 + prompt_length += 1 + output_length += 1 + super()._check_attentions_for_generate( + batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) - - for idx, iter_attentions in enumerate(attentions): - # adds PAD dummy token - tgt_len = min_length + idx + 1 - src_len = min_length + idx + 1 - - expected_shape = ( - batch_size * num_beam_groups, - config.num_attention_heads, - tgt_len, - src_len, - ) - # check attn size - self.assertListEqual( - [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) - ) def _check_hidden_states_for_generate( - self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False ): - self.assertIsInstance(hidden_states, tuple) - self.assertListEqual( - [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], - [True] * len(hidden_states), + # adds PAD dummy token, expected shape is off by 1 + prompt_length += 1 + output_length += 1 + super()._check_hidden_states_for_generate( + batch_size, hidden_states, prompt_length, output_length, config, use_cache ) - self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) - - for idx, iter_hidden_states in enumerate(hidden_states): - # adds PAD dummy token - seq_len = min_length + idx + 1 - expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) - # check hidden size - self.assertListEqual( - [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states], - [expected_shape] * len(iter_hidden_states), - ) - pass @slow def test_model_from_pretrained(self): diff --git a/tests/models/xlnet/test_modeling_xlnet.py b/tests/models/xlnet/test_modeling_xlnet.py index 630c8d2e63..4636efed10 100644 --- a/tests/models/xlnet/test_modeling_xlnet.py +++ b/tests/models/xlnet/test_modeling_xlnet.py @@ -636,57 +636,52 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi weight.data.fill_(3) def _check_hidden_states_for_generate( - self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False ): self.assertIsInstance(hidden_states, tuple) self.assertListEqual( [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states], [True] * len(hidden_states), ) - self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups) + self.assertEqual(len(hidden_states), (output_length - prompt_length)) - for idx, iter_hidden_states in enumerate(hidden_states): + for generated_length, iter_hidden_states in enumerate(hidden_states): # check hidden size for i, layer_hidden_states in enumerate(iter_hidden_states): # every 2nd tensor is from extra stream if i % 2 != 0: - seq_len = 1 + model_output_length = 1 else: # for first item dummy PAD token is appended so need one more # else offset+dummy_token when using cache - seq_len = (min_length + 1) if idx == 0 else 3 + model_output_length = (prompt_length + 1) if generated_length == 0 else 3 - expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size) + expected_shape = (batch_size, model_output_length, config.hidden_size) self.assertEqual(layer_hidden_states.shape, expected_shape) def _check_attentions_for_generate( - self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values ): self.assertIsInstance(attentions, tuple) self.assertListEqual( [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions) ) - self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups) + self.assertEqual(len(attentions), (output_length - prompt_length)) - for idx, attentions_item in enumerate(attentions): + for generated_length, attentions_item in enumerate(attentions): for iter_attentions in attentions_item: - tgt_len = min_length + model_input_length = prompt_length # for first item dummy PAD token is appended so need one more # every token after consists of offset+dummy_token length when using cache - if idx == 0: - tgt_len += 1 + if generated_length == 0: + model_input_length += 1 else: - tgt_len = 3 + model_input_length = 3 - src_len = min_length + idx + 1 + query_length = prompt_length + generated_length + 1 - expected_shape = ( - batch_size * num_beam_groups, - config.num_attention_heads, - tgt_len, - src_len, - ) + expected_shape = (batch_size, config.num_attention_heads, model_input_length, query_length) # check attn size self.assertListEqual( [layer_attention.shape for layer_attention in iter_attentions],