🔴[Attention] Attention refactor for Whisper-based models (#38235)
* start refactoring whisper * revert for now * first step * carry over attn fixes * check if this works * whisper has an off by one somewhere - cutting mask in any interface * make it based on interface * remove some tests that were skipped but now work * some fixes for whisper tests * interface changes * change the order of fix * some attention adjustments for eager + TP * fix scaling * mask changes * why does whisper contain those extra seq lens? * fix from config for fa2 as input_ids is invalid * fix another test * another fix * disable flex attn due to compile issues * copies and refactor for qwen audio since it somewhat relies on whisper * fix scaling and smaller things * retrigger * new new interface version + more fixups * adjust qwen * add comment * forgot this one * change copies as whisper cuts on the mask * add guard * add flex attention * switch to new mask function + add skips for torchscript * remove old api with cache position * last changes? * trigger ci
This commit is contained in:
@@ -1716,6 +1716,19 @@ class EncoderDecoderCache(Cache):
|
||||
self.self_attention_cache.batch_select_indices(indices)
|
||||
self.cross_attention_cache.batch_select_indices(indices)
|
||||
|
||||
def get_max_cache_shape(self) -> Optional[int]:
|
||||
"""Returns the maximum sequence length (i.e. max capacity) of the cache object"""
|
||||
return self.self_attention_cache.get_max_cache_shape()
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
|
||||
|
||||
|
||||
class HybridCache(Cache):
|
||||
"""
|
||||
|
||||
@@ -2051,6 +2051,10 @@ class WhisperNoSpeechDetection(LogitsProcessor):
|
||||
self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
|
||||
self.inputs["input_features"] = self.inputs.pop("inputs")
|
||||
|
||||
# Whisper encoder-decoder does not accept the input_ids as input
|
||||
if "input_ids" not in inspect.signature(self.model.forward).parameters:
|
||||
self.inputs.pop("input_ids", None)
|
||||
|
||||
@property
|
||||
def no_speech_prob(self):
|
||||
return self._no_speech_prob
|
||||
|
||||
@@ -636,7 +636,7 @@ class GenerationMixin(ContinuousMixin):
|
||||
and attention_mask is not None
|
||||
and attention_mask.ndim == 2
|
||||
):
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
if not self.config.is_encoder_decoder and model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs[input_ids_key].shape[:2]
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@@ -25,18 +25,13 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import auto_docstring, logging
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig
|
||||
|
||||
|
||||
if is_flash_attn_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -82,6 +77,37 @@ class Qwen2AudioCausalLMOutputWithPast(ModelOutput):
|
||||
attention_mask: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
# Copied from transformers.models.whisper.modeling_whisper.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: Optional[float] = None,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if scaling is None:
|
||||
scaling = query.size(-1) ** -0.5
|
||||
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Qwen2AudioAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@@ -135,210 +161,42 @@ class Qwen2AudioAttention(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
# Scaling is susceptible to floating point arithmetics' inprecisions
|
||||
# which can lead to different results (this is dependent from model
|
||||
# to model, e.g. whisper is one such case). We therefore keep the
|
||||
# original order of scaling to follow the original implementation
|
||||
# and enforce no scaling (1.0) in the attention call below.
|
||||
query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
||||
"""
|
||||
Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2.__init__ with Whisper->Qwen2Audio
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# Qwen2AudioFlashAttention2 attention does not support output_attentions
|
||||
if output_attentions:
|
||||
raise ValueError("Qwen2AudioFlashAttention2 attention does not support output_attentions")
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
|
||||
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, : key_states.shape[1]]
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
causal_mask,
|
||||
tgt_len,
|
||||
dropout=self.dropout if self.training else 0.0,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"Qwen2AudioModel is using Qwen2AudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=1.0,
|
||||
output_attentions=output_attentions,
|
||||
head_mask=layer_head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
|
||||
|
||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, None
|
||||
|
||||
|
||||
QWEN2AUDIO_ATTENTION_CLASSES = {
|
||||
"eager": Qwen2AudioAttention,
|
||||
"flash_attention_2": Qwen2AudioFlashAttention2,
|
||||
"sdpa": Qwen2AudioSdpaAttention,
|
||||
}
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO
|
||||
@@ -347,7 +205,7 @@ class Qwen2AudioEncoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
self.self_attn = QWEN2AUDIO_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.self_attn = Qwen2AudioAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
|
||||
@@ -636,6 +636,10 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
# passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
|
||||
# where the input ids are handled explicitly by the generate method
|
||||
self._check_decoder_input_ids(kwargs=kwargs)
|
||||
# `output_attentions` is deprecated - we force eager attention if this feature is
|
||||
# indirectly requested, e.g. through return_token_timestamps
|
||||
if return_token_timestamps:
|
||||
self.model.config._attn_implementation = "eager"
|
||||
|
||||
# 3. Retrieve logits processors
|
||||
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""PyTorch Whisper model."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -24,12 +24,11 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import (
|
||||
flash_attn_supports_top_left_mask,
|
||||
is_flash_attn_available,
|
||||
FlashAttentionKwargs,
|
||||
)
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
@@ -39,21 +38,13 @@ from ...modeling_outputs import (
|
||||
Seq2SeqModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, logging
|
||||
from .configuration_whisper import WhisperConfig
|
||||
from .generation_whisper import WhisperGenerationMixin
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
if is_flash_attn_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 1
|
||||
@@ -219,6 +210,36 @@ class WhisperPositionalEmbedding(nn.Embedding):
|
||||
return self.weight[position_ids]
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: Optional[float] = None,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if scaling is None:
|
||||
scaling = query.size(-1) ** -0.5
|
||||
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class WhisperAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@@ -262,29 +283,36 @@ class WhisperAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
# determine input shapes
|
||||
bsz, tgt_len = hidden_states.shape[:-1]
|
||||
q_input_shape = (bsz, tgt_len, -1, self.head_dim)
|
||||
|
||||
# Scaling is susceptible to floating point arithmetics' inprecisions
|
||||
# which can lead to different results (this is dependent from model
|
||||
# to model, e.g. whisper is one such case). We therefore keep the
|
||||
# original order of scaling to follow the original implementation
|
||||
# and enforce no scaling (1.0) in the attention call below.
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
query_states = query_states.view(*q_input_shape)
|
||||
query_states = query_states.transpose(1, 2).contiguous()
|
||||
|
||||
if past_key_value is not None:
|
||||
@@ -314,278 +342,36 @@ class WhisperAttention(nn.Module):
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = torch.matmul(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class WhisperFlashAttention2(WhisperAttention):
|
||||
"""
|
||||
Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
|
||||
def forward(
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if isinstance(past_key_value, StaticCache):
|
||||
raise ValueError(
|
||||
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
|
||||
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
|
||||
)
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
|
||||
|
||||
if past_key_value is not None:
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
|
||||
# use key_value_states if cross attention
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value.key_cache[self.layer_idx]
|
||||
value_states = past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self.k_proj(current_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
value_states = self.v_proj(current_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
key_states = key_states.transpose(1, 2).contiguous()
|
||||
value_states = value_states.transpose(1, 2).contiguous()
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
|
||||
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, : key_states.shape[1]]
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
causal_mask,
|
||||
tgt_len,
|
||||
dropout=self.dropout if self.training else 0.0,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class WhisperSdpaAttention(WhisperAttention):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention"
|
||||
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states,
|
||||
key_value_states=key_value_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=1.0,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
head_mask=layer_head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim)
|
||||
query_states = query_states.transpose(1, 2).contiguous()
|
||||
|
||||
if past_key_value is not None:
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
|
||||
# use key_value_states if cross attention
|
||||
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||
if is_cross_attention and past_key_value and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value.key_cache[self.layer_idx]
|
||||
value_states = past_key_value.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
|
||||
value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
|
||||
key_states = key_states.transpose(1, 2).contiguous()
|
||||
value_states = value_states.transpose(1, 2).contiguous()
|
||||
if past_key_value is not None:
|
||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||
is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
|
||||
|
||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
WHISPER_ATTENTION_CLASSES = {
|
||||
"eager": WhisperAttention,
|
||||
"flash_attention_2": WhisperFlashAttention2,
|
||||
"sdpa": WhisperSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
# (BC Dep) Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
|
||||
# TODO(vasqu): fix copies when enabling whisper attn interface
|
||||
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
|
||||
class WhisperEncoderLayer(nn.Module):
|
||||
def __init__(self, config: WhisperConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.self_attn = WhisperAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@@ -653,7 +439,7 @@ class WhisperDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.self_attn = WhisperAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@@ -667,7 +453,7 @@ class WhisperDecoderLayer(nn.Module):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.encoder_attn = WhisperAttention(
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
dropout=config.attention_dropout,
|
||||
@@ -774,6 +560,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
@@ -1142,12 +929,12 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
past_key_values.self_attention_cache if past_key_values is not None else None,
|
||||
output_attentions,
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
@@ -1237,131 +1024,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, "BlockMask"],
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
if self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask)
|
||||
return attention_mask
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_compilable_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type in ["cuda", "xpu", "npu"]
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class WhisperModel(WhisperPreTrainedModel):
|
||||
|
||||
@@ -156,6 +156,10 @@ class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.Tes
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Qwen2 Audio does not support right padding.")
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
pass
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
# overwrite because Qwen2 is audio+text model (not vision+text)
|
||||
|
||||
@@ -31,6 +31,7 @@ from transformers import WhisperConfig
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_fp16,
|
||||
@@ -542,8 +543,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip
|
||||
def test_generate_with_head_masking(self):
|
||||
@parameterized.expand([("offloaded",)])
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Whisper doesn't work with offloaded cache implementation yet")
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
pass
|
||||
|
||||
@require_torch_fp16
|
||||
@@ -660,6 +663,9 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", 1)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
@@ -849,7 +855,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
@unittest.skip
|
||||
@unittest.skip(reason="Whisper encoder-decoder requires the features directly and can not work on ids only.")
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@@ -1422,6 +1428,21 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_generate_compilation_all_outputs(self):
|
||||
pass
|
||||
|
||||
# TODO (cyril): fix me :)
|
||||
@unittest.skip(reason="Torchscript doesn't work with the new mask creation functions")
|
||||
def test_torchscript_output_attentions(self):
|
||||
pass
|
||||
|
||||
# TODO (cyril): fix me :)
|
||||
@unittest.skip(reason="Torchscript doesn't work with the new mask creation functions")
|
||||
def test_torchscript_output_hidden_state(self):
|
||||
pass
|
||||
|
||||
# TODO (cyril): fix me :)
|
||||
@unittest.skip(reason="Torchscript doesn't work with the new mask creation functions")
|
||||
def test_torchscript_simple(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@@ -1684,6 +1705,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@require_read_token
|
||||
@slow
|
||||
def test_large_batched_generation_multilingual(self):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
@@ -1775,7 +1797,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_close(generated_ids, EXPECTED_OUTPUT)
|
||||
torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT)
|
||||
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
{
|
||||
@@ -2016,7 +2038,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50365, 293, 393, 4411, 50430
|
||||
])
|
||||
# fmt: on
|
||||
torch.testing.assert_close(generated_ids, EXPECTED_OUTPUT)
|
||||
torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT)
|
||||
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
{
|
||||
@@ -3610,27 +3632,10 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
|
||||
config=config, input_ids=inputs_dict["input_ids"]
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Tested implicitly through the encoder-decoder tests")
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Generate needs input ids")
|
||||
def test_generate_without_input_ids(self):
|
||||
# generate only works with input ids for whisper
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Decoder can't keep attention grads")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
@unittest.skip(
|
||||
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
|
||||
)
|
||||
def test_flash_attn_2_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
|
||||
)
|
||||
def test_flash_attn_2_inference_padding_right(self):
|
||||
pass
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_flex_attention_with_grads():
|
||||
return
|
||||
|
||||
@@ -4268,24 +4268,28 @@ class ModelTesterMixin:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# TODO: to change it in the future with other relevant auto classes
|
||||
fa2_model = model_class._from_config(
|
||||
config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
|
||||
config, attn_implementation="flash_attention_2", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
||||
dummy_input = inputs_dict[fa2_model.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
if fa2_model.config.is_encoder_decoder:
|
||||
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
_ = fa2_model(
|
||||
input_ids=dummy_input,
|
||||
dummy_input,
|
||||
attention_mask=dummy_attention_mask,
|
||||
decoder_input_ids=dummy_input.clone(),
|
||||
decoder_attention_mask=dummy_attention_mask.clone(),
|
||||
decoder_input_ids=dummy_decoder_input_ids,
|
||||
decoder_attention_mask=dummy_decoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||
_ = fa2_model(dummy_input, attention_mask=dummy_attention_mask)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fa2_model.save_pretrained(tmpdirname)
|
||||
|
||||
Reference in New Issue
Block a user