From badc71b9f604ca910bb87a43979c795eaf6e7d64 Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Wed, 28 May 2025 13:32:38 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=B4[`Attention`]=20Attention=20refacto?= =?UTF-8?q?r=20for=20Whisper-based=20models=20(#38235)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * start refactoring whisper * revert for now * first step * carry over attn fixes * check if this works * whisper has an off by one somewhere - cutting mask in any interface * make it based on interface * remove some tests that were skipped but now work * some fixes for whisper tests * interface changes * change the order of fix * some attention adjustments for eager + TP * fix scaling * mask changes * why does whisper contain those extra seq lens? * fix from config for fa2 as input_ids is invalid * fix another test * another fix * disable flex attn due to compile issues * copies and refactor for qwen audio since it somewhat relies on whisper * fix scaling and smaller things * retrigger * new new interface version + more fixups * adjust qwen * add comment * forgot this one * change copies as whisper cuts on the mask * add guard * add flex attention * switch to new mask function + add skips for torchscript * remove old api with cache position * last changes? 
* trigger ci --- src/transformers/cache_utils.py | 13 + src/transformers/generation/logits_process.py | 4 + src/transformers/generation/utils.py | 2 +- .../qwen2_audio/modeling_qwen2_audio.py | 246 ++------- .../models/whisper/generation_whisper.py | 4 + .../models/whisper/modeling_whisper.py | 496 +++--------------- .../qwen2_audio/test_modeling_qwen2_audio.py | 4 + tests/models/whisper/test_modeling_whisper.py | 55 +- tests/test_modeling_common.py | 22 +- 9 files changed, 200 insertions(+), 646 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 286d03a559..1a3ba7f8df 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1716,6 +1716,19 @@ class EncoderDecoderCache(Cache): self.self_attention_cache.batch_select_indices(indices) self.cross_attention_cache.batch_select_indices(indices) + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length (i.e. max capacity) of the cache object""" + return self.self_attention_cache.get_max_cache_shape() + + def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. 
+ """ + return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx) + class HybridCache(Cache): """ diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 6e0f0154ab..a4e8b5eda0 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -2051,6 +2051,10 @@ class WhisperNoSpeechDetection(LogitsProcessor): self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs} self.inputs["input_features"] = self.inputs.pop("inputs") + # Whisper encoder-decoder does not accept the input_ids as input + if "input_ids" not in inspect.signature(self.model.forward).parameters: + self.inputs.pop("input_ids", None) + @property def no_speech_prob(self): return self._no_speech_prob diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 784b2d15b6..713d57a899 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -636,7 +636,7 @@ class GenerationMixin(ContinuousMixin): and attention_mask is not None and attention_mask.ndim == 2 ): - if model_inputs["inputs_embeds"] is not None: + if not self.config.is_encoder_decoder and model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape else: batch_size, sequence_length = model_inputs[input_ids_key].shape[:2] diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 9ac1fd008a..1f7b144798 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -25,18 +25,13 @@ from torch import nn from ...activations 
import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, ModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import auto_docstring, logging from ..auto import AutoModel, AutoModelForCausalLM from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) @@ -82,6 +77,37 @@ class Qwen2AudioCausalLMOutputWithPast(ModelOutput): attention_mask: Optional[torch.FloatTensor] = None +# Copied from transformers.models.whisper.modeling_whisper.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None and attention_mask.ndim == 4: + attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]] + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Qwen2AudioAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -135,219 +161,51 @@ class 
Qwen2AudioAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" bsz, tgt_len, _ = hidden_states.size() - # get query proj + # Scaling is susceptible to floating point arithmetic imprecisions + # which can lead to different results (this varies from model + # to model, e.g. whisper is one such case). We therefore keep the + # original order of scaling to follow the original implementation + # and enforce no scaling (1.0) in the attention call below. query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz) key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_probs, value_states) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, 
tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights, None - - -class Qwen2AudioFlashAttention2(Qwen2AudioAttention): - """ - Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2.__init__ with Whisper->Qwen2Audio - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # Qwen2AudioFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions") - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim)) - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim] - # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, : key_states.shape[1]] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - causal_mask, - tgt_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=1.0, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, tgt_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, None -class Qwen2AudioSdpaAttention(Qwen2AudioAttention): - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. 
`model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2AudioModel is using Qwen2AudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. 
Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, None - - -QWEN2AUDIO_ATTENTION_CLASSES = { - "eager": Qwen2AudioAttention, - "flash_attention_2": Qwen2AudioFlashAttention2, - "sdpa": Qwen2AudioSdpaAttention, -} - - # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO class Qwen2AudioEncoderLayer(nn.Module): def __init__(self, config: Qwen2AudioConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = QWEN2AUDIO_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = Qwen2AudioAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index db362355b8..08552d4114 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -636,6 +636,10 @@ class WhisperGenerationMixin(GenerationMixin): # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation # where the input ids are handled explicitly by 
the generate method self._check_decoder_input_ids(kwargs=kwargs) + # `output_attentions` is deprecated - we force eager attention if this feature is + # indirectly requested, e.g. through return_token_timestamps + if return_token_timestamps: + self.model.config._attn_implementation = "eager" # 3. Retrieve logits processors device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index ef847d6059..7bb07a6c1c 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -15,7 +15,7 @@ """PyTorch Whisper model.""" import math -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -24,12 +24,11 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import ( - flash_attn_supports_top_left_mask, - is_flash_attn_available, + FlashAttentionKwargs, ) from ...modeling_outputs import ( BaseModelOutput, @@ -39,21 +38,13 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, SequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, logging from .configuration_whisper import WhisperConfig from .generation_whisper import WhisperGenerationMixin -if 
is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) _HIDDEN_STATES_START_POSITION = 1 @@ -219,6 +210,36 @@ class WhisperPositionalEmbedding(nn.Embedding): return self.weight[position_ids] +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None and attention_mask.ndim == 4: + attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]] + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class WhisperAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -262,29 +283,36 @@ class WhisperAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[EncoderDecoderCache] = None, + past_key_value: 
Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - # get query proj + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + + # Scaling is susceptible to floating point arithmetic imprecisions + # which can lead to different results (this varies from model + # to model, e.g. whisper is one such case). We therefore keep the + # original order of scaling to follow the original implementation + # and enforce no scaling (1.0) in the attention call below. 
query_states = self.q_proj(hidden_states) * self.scaling - query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim) + query_states = query_states.view(*q_input_shape) query_states = query_states.transpose(1, 2).contiguous() if past_key_value is not None: @@ -314,278 +342,36 @@ class WhisperAttention(nn.Module): key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_probs, value_states) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. 
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=1.0, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) return attn_output, attn_weights, past_key_value -class WhisperFlashAttention2(WhisperAttention): - """ - Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[EncoderDecoderCache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. " - "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers" - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim)) - - if past_key_value is not None: - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache - else: - past_key_value = past_key_value.self_attention_cache - - # use key_value_states if cross attention - current_states = key_value_states if key_value_states is not None else hidden_states - if is_cross_attention and past_key_value and is_updated: - # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states).view(bsz, tgt_len, self.num_heads, self.head_dim) - value_states = 
self.v_proj(current_states).view(bsz, tgt_len, self.num_heads, self.head_dim) - key_states = key_states.transpose(1, 2).contiguous() - value_states = value_states.transpose(1, 2).contiguous() - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim] - # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, : key_states.shape[1]] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - causal_mask, - tgt_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, tgt_len, -1) - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class WhisperSdpaAttention(WhisperAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[EncoderDecoderCache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).contiguous() - - if past_key_value is not None: - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache - else: - past_key_value = past_key_value.self_attention_cache - - # use key_value_states if cross attention - current_states = key_value_states if key_value_states is not None else hidden_states - if is_cross_attention and past_key_value and is_updated: - # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) - value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) - key_states = key_states.transpose(1, 2).contiguous() - value_states = value_states.transpose(1, 2).contiguous() - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - - causal_mask = attention_mask - if 
attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. 
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -WHISPER_ATTENTION_CLASSES = { - "eager": WhisperAttention, - "flash_attention_2": WhisperFlashAttention2, - "sdpa": WhisperSdpaAttention, -} - - -# (BC Dep) Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER -# TODO(vasqu): fix copies when enabling whisper attn interface +# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER class WhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -653,7 +439,7 @@ class WhisperDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -667,7 +453,7 @@ class WhisperDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = WhisperAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -774,6 +560,7 @@ class WhisperPreTrainedModel(PreTrainedModel): _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True _supports_cache_class = True _supports_static_cache = True @@ -1142,12 +929,12 @@ class WhisperDecoder(WhisperPreTrainedModel): hidden_states = 
inputs_embeds + positions.to(inputs_embeds.device) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - past_key_values.self_attention_cache if past_key_values is not None else None, - output_attentions, + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, ) if self.gradient_checkpointing and self.training: @@ -1237,131 +1024,6 @@ class WhisperDecoder(WhisperPreTrainedModel): cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, "BlockMask"], - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return attention_mask - return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. 
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype = input_tensor.dtype - sequence_length = input_tensor.shape[1] - if using_compilable_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
- # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
- causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - @auto_docstring class WhisperModel(WhisperPreTrainedModel): diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index 3d2e2c7d37..571ac07370 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -156,6 +156,10 @@ class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.Tes def test_sdpa_can_dispatch_on_flash(self): pass + @unittest.skip(reason="Qwen2 Audio does not support right padding.") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): # overwrite because Qwen2 is audio+text model (not vision+text) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index ffc4b59abb..1397bbe4dc 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -31,6 +31,7 @@ from transformers import WhisperConfig from 
transformers.testing_utils import ( is_flaky, require_flash_attn, + require_read_token, require_torch, require_torch_accelerator, require_torch_fp16, @@ -542,8 +543,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_training_gradient_checkpointing_use_reentrant_false(self): pass - @unittest.skip - def test_generate_with_head_masking(self): + @parameterized.expand([("offloaded",)]) + @pytest.mark.generate + @unittest.skip(reason="Whisper doesn't work with offloaded cache implementation yet") + def test_offloaded_cache_implementation(self, cache_implementation): pass @require_torch_fp16 @@ -660,6 +663,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + seq_len = getattr(self.model_tester, "seq_length", None) decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", 1) encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) @@ -849,7 +855,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi # Check that the model can still do a forward pass successfully (every parameter should be resized) model(**self._prepare_for_class(inputs_dict, model_class)) - @unittest.skip + @unittest.skip(reason="Whisper encoder-decoder requires the features directly and can not work on ids only.") def test_generate_without_input_ids(self): pass @@ -1422,6 +1428,21 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi def test_generate_compilation_all_outputs(self): pass + # TODO (cyril): fix me :) + @unittest.skip(reason="Torchscript doesn't work with the new mask creation functions") + def test_torchscript_output_attentions(self): + pass + + # TODO (cyril): fix me :) + @unittest.skip(reason="Torchscript doesn't work 
with the new mask creation functions") + def test_torchscript_output_hidden_state(self): + pass + + # TODO (cyril): fix me :) + @unittest.skip(reason="Torchscript doesn't work with the new mask creation functions") + def test_torchscript_simple(self): + pass + @require_torch @require_torchaudio @@ -1684,6 +1705,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) + @require_read_token @slow def test_large_batched_generation_multilingual(self): processor = WhisperProcessor.from_pretrained("openai/whisper-large") @@ -1775,7 +1797,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ]) # fmt: on - torch.testing.assert_close(generated_ids, EXPECTED_OUTPUT) + torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT) EXPECTED_TRANSCRIPT = [ { @@ -2016,7 +2038,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50430 ]) # fmt: on - torch.testing.assert_close(generated_ids, EXPECTED_OUTPUT) + torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT) EXPECTED_TRANSCRIPT = [ { @@ -3610,27 +3632,10 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, config=config, input_ids=inputs_dict["input_ids"] ) - @unittest.skip(reason="Tested implicitly through the encoder-decoder tests") - def 
test_custom_4d_attention_mask(self): - pass - - @unittest.skip(reason="Generate needs input ids") - def test_generate_without_input_ids(self): - # generate only works with input ids for whisper - pass - @unittest.skip(reason="Decoder can't keep attention grads") def test_retain_grad_hidden_states_attentions(self): return - @unittest.skip( - "Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test" - ) - def test_flash_attn_2_inference(self): - pass - - @unittest.skip( - "Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test" - ) - def test_flash_attn_2_inference_padding_right(self): - pass + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(self): + return diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 512f6346e7..9ee6a93dbb 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4268,24 +4268,28 @@ class ModelTesterMixin: if not model_class._supports_flash_attn_2: self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() # TODO: to change it in the future with other relevant auto classes fa2_model = model_class._from_config( - config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16 + config, attn_implementation="flash_attention_2", torch_dtype=torch.float16 ).to(torch_device) - dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) + dummy_input = inputs_dict[fa2_model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + dummy_attention_mask
= inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - if config.is_encoder_decoder: + if fa2_model.config.is_encoder_decoder: + dummy_decoder_input_ids = inputs_dict["decoder_input_ids"] + dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"] _ = fa2_model( - input_ids=dummy_input, + dummy_input, attention_mask=dummy_attention_mask, - decoder_input_ids=dummy_input.clone(), - decoder_attention_mask=dummy_attention_mask.clone(), + decoder_input_ids=dummy_decoder_input_ids, + decoder_attention_mask=dummy_decoder_attention_mask, ) else: - _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) + _ = fa2_model(dummy_input, attention_mask=dummy_attention_mask) with tempfile.TemporaryDirectory() as tmpdirname: fa2_model.save_pretrained(tmpdirname)