From 365fecb4d0b6c87f20b93561e11c3d4c77938012 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 30 Jan 2025 12:43:00 +0100 Subject: [PATCH] Whisper: fix static cache CI (#35852) * fix * remove overridden method * small change --- src/transformers/generation/utils.py | 17 ++-- .../models/whisper/generation_whisper.py | 2 +- .../models/whisper/modeling_whisper.py | 84 +------------------ tests/models/whisper/test_modeling_whisper.py | 5 +- 4 files changed, 15 insertions(+), 93 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index cb6ec15bb9..fed276b323 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -406,23 +406,28 @@ class GenerationMixin: model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) # 4. Create missing `position_ids` on the fly + attention_mask = ( + kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask + ) + attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask" + position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids" if ( attention_mask is not None - and kwargs.get("position_ids") is None - and "position_ids" in set(inspect.signature(self.forward).parameters.keys()) + and kwargs.get(position_ids_key) is None + and position_ids_key in set(inspect.signature(self.forward).parameters.keys()) ): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below) + kwargs[position_ids_key] = position_ids # placed in kwargs for further processing (see below) # 5. 
Slice model inputs if it's an input that should have the same length as `input_ids` - for model_input_name in ["position_ids", "token_type_ids"]: + for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]: model_input = kwargs.get(model_input_name) if model_input is not None: if past_key_values is not None: current_input_length = ( model_inputs["inputs_embeds"].shape[1] - if model_inputs["inputs_embeds"] is not None + if model_inputs.get("inputs_embeds") is not None else model_inputs[input_ids_key].shape[1] ) model_input = model_input[:, -current_input_length:] @@ -469,7 +474,7 @@ class GenerationMixin: past_key_values=past_key_values, ) if attention_mask is not None: - model_inputs["attention_mask"] = attention_mask + model_inputs[attention_mask_key] = attention_mask # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`). for key, value in kwargs.items(): diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index 1b4ecb831b..035b4fb890 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1234,7 +1234,7 @@ class WhisperGenerationMixin(GenerationMixin): def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs): set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs") extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)} - set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs}) + set_inputs({"inputs": segment_input, "input_ids": decoder_input_ids, **extra_kwargs}) @staticmethod def _retrieve_total_input_frames(input_features, input_stride, kwargs): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 150d683e92..f7bbffdbc5 100644 --- 
a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1255,7 +1255,7 @@ class WhisperDecoder(WhisperPreTrainedModel): ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = cache_position.unsqueeze(0).repeat(input_shape[0], 1) # embed positions if input_ids is not None: @@ -1806,88 +1806,6 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM encoder_attentions=outputs.encoder_attentions, ) - def prepare_inputs_for_generation( - self, - decoder_input_ids, - past_key_values=None, - use_cache=None, - encoder_outputs=None, - attention_mask=None, - decoder_attention_mask=None, - cache_position=None, - **kwargs, - ): - # Overwritten -- encoder-decoder whisper has custom logic, but it's close to the general function. Next time - # this function needs to be touched, let's try to sort out the commonalities between the two and remove the - # overwrite. - - decoder_position_ids = None - if decoder_attention_mask is not None: - decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0) - - past_length = 0 - if past_key_values is not None: - if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - else: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - if decoder_position_ids is not None: - decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would 
have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format) - - if cache_position is None: - cache_position = torch.arange( - past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device - ) - elif use_cache: - cache_position = cache_position[-decoder_input_ids.shape[1] :] - - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 - decoder_input_ids = decoder_input_ids.contiguous() - - if ( - isinstance(past_key_values, EncoderDecoderCache) - and ( - isinstance(past_key_values.self_attention_cache, StaticCache) - or isinstance(past_key_values.cross_attention_cache, StaticCache) - ) - and decoder_attention_mask is not None - and decoder_attention_mask.ndim == 2 - ): - batch_size, sequence_length = decoder_input_ids.shape - - decoder_attention_mask = self.get_decoder()._prepare_4d_causal_attention_mask_with_cache_position( - decoder_attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.self_attention_cache.get_max_cache_shape(), - dtype=self.proj_out.weight.dtype, - device=decoder_input_ids.device, - cache_position=cache_position, - batch_size=batch_size, - ) - - return { - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "use_cache": use_cache, - "decoder_attention_mask": decoder_attention_mask, - "decoder_position_ids": decoder_position_ids, - "cache_position": cache_position, - } - class WhisperDecoderWrapper(WhisperPreTrainedModel): """ diff --git a/tests/models/whisper/test_modeling_whisper.py 
b/tests/models/whisper/test_modeling_whisper.py index ce30ea4eae..fe41afabf4 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -3323,8 +3323,8 @@ class WhisperModelIntegrationTests(unittest.TestCase): input_features = input_features.to(torch_device) eager_generated_ids = model.generate(input_features, max_new_tokens=64) + # Using static cache compiles forward for each decoding step, so we don't have to manually compile model.generation_config.cache_implementation = "static" - model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) # compile the forward pass and assert equivalence static_generated_ids = model.generate(input_features, max_new_tokens=64) @@ -3379,9 +3379,8 @@ class WhisperModelIntegrationTests(unittest.TestCase): set_seed(42) eager_generated_ids = model.generate(**inputs, **gen_kwargs) - # compile the forward pass and assert equivalence + # Using static cache compiles forward for each decoding step, so we don't have to manually compile model.generation_config.cache_implementation = "static" - model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) set_seed(42) static_generated_ids = model.generate(**inputs, **gen_kwargs)