diff --git a/docs/source/en/model_doc/biogpt.md b/docs/source/en/model_doc/biogpt.md index d7145993a8..11ab89c9f2 100644 --- a/docs/source/en/model_doc/biogpt.md +++ b/docs/source/en/model_doc/biogpt.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention SDPA
@@ -40,13 +41,13 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The ### Using Scaled Dot Product Attention (SDPA) -PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function -encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the -[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) page for more information. -SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. ``` @@ -109,7 +110,7 @@ we saw the following speedups during inference. [[autodoc]] BioGptForCausalLM - forward - + ## BioGptForTokenClassification [[autodoc]] BioGptForTokenClassification diff --git a/docs/source/en/model_doc/blenderbot-small.md b/docs/source/en/model_doc/blenderbot-small.md index 647a865de3..341e43c030 100644 --- a/docs/source/en/model_doc/blenderbot-small.md +++ b/docs/source/en/model_doc/blenderbot-small.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA Note that [`BlenderbotSmallModel`] and @@ -52,7 +54,7 @@ found [here](https://github.com/facebookresearch/ParlAI). ## Usage tips -Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than +Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. diff --git a/docs/source/en/model_doc/blenderbot.md b/docs/source/en/model_doc/blenderbot.md index ec24d5ed74..adfa6841e1 100644 --- a/docs/source/en/model_doc/blenderbot.md +++ b/docs/source/en/model_doc/blenderbot.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA ## Overview @@ -45,7 +47,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The ## Usage tips and example -Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right +Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. An example: @@ -71,7 +73,7 @@ An example: `facebook/blenderbot_small_90M`, have a different architecture and consequently should be used with [BlenderbotSmall](blenderbot-small). - + ## Resources - [Causal language modeling task guide](../tasks/language_modeling) diff --git a/docs/source/en/model_doc/marian.md b/docs/source/en/model_doc/marian.md index 80bb73d26d..4fcd636355 100644 --- a/docs/source/en/model_doc/marian.md +++ b/docs/source/en/model_doc/marian.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA ## Overview @@ -155,7 +157,7 @@ Example of translating english to many romance languages, using old-style 2 char >>> model = MarianMTModel.from_pretrained(model_name) >>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True)) >>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated] -["c'est une phrase en anglais que nous voulons traduire en français", +["c'est une phrase en anglais que nous voulons traduire en français", 'Isto deve ir para o português.', 'Y esto al español'] ``` diff --git a/docs/source/en/model_doc/nllb-moe.md b/docs/source/en/model_doc/nllb-moe.md index 65a4812ed6..fc8c8c9211 100644 --- a/docs/source/en/model_doc/nllb-moe.md +++ b/docs/source/en/model_doc/nllb-moe.md @@ -51,10 +51,10 @@ The original code can be found [here](https://github.com/facebookresearch/fairse ## Implementation differences with SwitchTransformers -The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the -highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed, -which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden -states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism. +The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the +highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed, +which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden +states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism. ## Generating with NLLB-MoE diff --git a/docs/source/en/model_doc/pegasus.md b/docs/source/en/model_doc/pegasus.md index bdb61e66d9..5681ac9b58 100644 --- a/docs/source/en/model_doc/pegasus.md +++ b/docs/source/en/model_doc/pegasus.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA ## Overview diff --git a/docs/source/en/model_doc/pegasus_x.md b/docs/source/en/model_doc/pegasus_x.md index 3f982263cd..97e50601b7 100644 --- a/docs/source/en/model_doc/pegasus_x.md +++ b/docs/source/en/model_doc/pegasus_x.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention
## Overview diff --git a/docs/source/en/model_doc/plbart.md b/docs/source/en/model_doc/plbart.md index bac567615d..d57ee8ed99 100644 --- a/docs/source/en/model_doc/plbart.md +++ b/docs/source/en/model_doc/plbart.md @@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention +SDPA
## Overview @@ -29,7 +31,7 @@ on Java, Python and English. According to the abstract *Code summarization and generation empower conversion between programming language (PL) and natural language (NL), -while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART, +while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART, a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks. PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding. Experiments on code summarization in the English language, code generation, and code translation in seven programming languages @@ -50,7 +52,7 @@ target text format is `[tgt_lang_code] X [eos]`. `bos` is never used. However, for fine-tuning, in some cases no language token is provided in cases where a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this. -In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format +In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format when you pass texts as the first argument or with the keyword argument `text`, and will encode target text format if it's passed with the `text_target` keyword argument. diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index a78166ed04..4f76e65a84 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -3,8 +3,11 @@ from typing import Optional, Tuple import torch from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ..utils import logging +logger = logging.get_logger(__name__) + _use_top_left_mask = flash_attn_supports_top_left_mask() @@ -20,6 +23,12 @@ def flash_attention_forward( softcap: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: + if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None: + logger.warning_once( + "`flash_attention_2` does not support `output_attentions=True` or `head_mask`." + " Please set your attention to `eager` if you want any of these features." + ) + # This is before the transpose seq_len = query.shape[2] diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index afdaba5199..cc9787657b 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -31,13 +31,15 @@ from typing import Optional, Tuple, Union import torch from packaging import version -from ..utils import is_torch_flex_attn_available +from ..utils import is_torch_flex_attn_available, logging from ..utils.import_utils import _torch_version, is_torchdynamo_compiling if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask, flex_attention - from torch.nn.attention.flex_attention import create_block_mask as create_block_causal_mask_flex + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + + +logger = logging.get_logger(__name__) class WrappedFlexAttention: @@ -98,21 +100,23 @@ def compile_friendly_flex_attention( Offset = Union[torch.Tensor, int] +# TODO: deprecate / rename to make_flex_block_mask for clarity as it's not only causal anymore def make_flex_block_causal_mask( attention_mask_2d: torch.Tensor, attention_chunk_size: Optional[int] = None, query_length=None, key_length=None, offsets: Optional[Tuple[Offset, Offset]] = None, + is_causal: Optional[bool] = True, ) -> "BlockMask": """ IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in `masking_utils.py`, and will be removed in a future version without warnings. New code should not use it. It is only kept here for BC for now, while models using it are being patched accordingly. - Create a block causal document mask for a batch of sequences, both packed and unpacked. - Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. - The resultant BlockMask is a compressed representation of the full block causal + Create a block (causal) document mask for a batch of sequences, both packed and unpacked. + Create Block (causal) logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. + The resultant BlockMask is a compressed representation of the full (causal) block mask. BlockMask is essential for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/ @@ -170,7 +174,21 @@ def make_flex_block_causal_mask( causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx) return chunk_mask & causal_doc_mask - mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod + def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + """ + Utilizes default attention mask to enable encoder and encoder-decoder + attention masks. + """ + document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx] + # kv indexing is crucial in order to work correctly + padding_mask = attention_mask_2d[batch_idx, kv_idx] > 0 + final_mask = padding_mask & document_mask + return final_mask + + if not is_causal: + mask_mod_maybe_combined = default_mask_mod + else: + mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod if offsets is not None: q_offset = offsets[0] @@ -182,7 +200,8 @@ def make_flex_block_causal_mask( return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv) else: mask_mod = mask_mod_maybe_combined - return create_block_causal_mask_flex( + + return create_block_mask( mask_mod=mask_mod, B=batch_size, H=None, # attention head @@ -216,21 +235,33 @@ def flex_attention_forward( head_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + if kwargs.get("output_attentions", False) or head_mask is not None: + logger.warning_once( + "`flex_attention` does not support `output_attentions=True` or `head_mask`." + " Please set your attention to `eager` if you want any of these features." + ) + + if kwargs.get("dropout", 0.0) > 0: + raise ValueError( + "`flex_attention` does not support `dropout`. Please use it with inference" + " only (`model.eval()`) or turn off the attention dropout in the respective config." + ) + block_mask = None - causal_mask = None + score_mask = None if isinstance(attention_mask, BlockMask): block_mask = attention_mask else: - causal_mask = attention_mask + score_mask = attention_mask - if causal_mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] + if score_mask is not None: + score_mask = score_mask[:, :, :, : key.shape[-2]] def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): if softcap is not None: score = softcap * torch.tanh(score / softcap) - if causal_mask is not None: - score = score + causal_mask[batch_idx][0][q_idx][kv_idx] + if score_mask is not None: + score = score + score_mask[batch_idx][0][q_idx][kv_idx] if head_mask is not None: score = score + head_mask[batch_idx][head_idx][0][0] return score diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 9c924c048a..247cd28216 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -2,6 +2,11 @@ from typing import Optional, Tuple import torch +from ..utils import logging + + +logger = logging.get_logger(__name__) + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ @@ -26,13 +31,18 @@ def sdpa_attention_forward( is_causal: Optional[bool] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: + if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None: + logger.warning_once( + "`sdpa` attention does not support `output_attentions=True` or `head_mask`." + " Please set your attention to `eager` if you want any of these features." + ) + if hasattr(module, "num_key_value_groups"): key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) - causal_mask = attention_mask - if attention_mask is not None and causal_mask.ndim == 4: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions # Reference: https://github.com/pytorch/pytorch/issues/112577. @@ -44,7 +54,9 @@ def sdpa_attention_forward( # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool` if is_causal is None: - is_causal = query.shape[2] > 1 and causal_mask is None + # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag + # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns + is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True) # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # We convert it to a bool for the SDPA kernel that only accepts bools. @@ -55,7 +67,7 @@ def sdpa_attention_forward( query, key, value, - attn_mask=causal_mask, + attn_mask=attention_mask, dropout_p=dropout, scale=scaling, is_causal=is_causal, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 97e95b4161..0ad3947815 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -165,6 +165,7 @@ if is_safetensors_available(): if is_kernels_available(): from kernels import get_kernel + logger = logging.get_logger(__name__) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 9f7f1515a2..0a41692f69 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -26,14 +26,21 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput from ...modeling_utils import PreTrainedModel from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_autoformer import AutoformerConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -904,6 +911,29 @@ class AutoformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoder with TimeSeriesTransformer->Autoformer,TimeSeries->Autoformer class AutoformerEncoder(AutoformerPreTrainedModel): @@ -983,10 +1013,10 @@ class AutoformerEncoder(AutoformerPreTrainedModel): hidden_states = self.layernorm_embedding(hidden_states + embed_pos) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 01f7f19a79..60d9cdba2a 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -17,7 +17,7 @@ import copy import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -32,7 +32,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -42,7 +42,8 @@ from ...modeling_outputs import ( Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -53,13 +54,7 @@ from .configuration_bart import BartConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -119,6 +114,36 @@ class BartScaledWordEmbedding(nn.Embedding): return super().forward(input_ids) * self.embed_scale +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class BartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -170,17 +195,25 @@ class BartAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -201,8 +234,8 @@ class BartAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -214,297 +247,27 @@ class BartAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class BartFlashAttention2(BartAttention): - """ - Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # BartFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError( - "BartSdpaAttention2 attention does not support `output_attentions`. " - "Use the argument `attn_implementation='eager'` when loading the model." - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value - - -class BartSdpaAttention(BartAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - causal_mask = None - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -BART_ATTENTION_CLASSES = { - "eager": BartAttention, - "sdpa": BartSdpaAttention, - "flash_attention_2": BartFlashAttention2, -} + return attn_output, attn_weights, past_key_value class BartEncoderLayer(nn.Module): @@ -512,7 +275,7 @@ class BartEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BartAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -583,7 +346,7 @@ class BartDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BartAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -597,7 +360,7 @@ class BartDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = BartAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -671,6 +434,7 @@ class BartDecoderLayer(nn.Module): layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -730,6 +494,8 @@ class BartPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True @@ -757,23 +523,53 @@ class BartPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -782,7 +578,7 @@ class BartPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -816,7 +612,6 @@ class BartPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -882,6 +677,41 @@ class BartPreTrainedModel(PreTrainedModel): return causal_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class PretrainedBartModel(BartPreTrainedModel): def __init_subclass__(self): @@ -932,8 +762,6 @@ class BartEncoder(BartPreTrainedModel): embed_dim, ) self.layers = nn.ModuleList([BartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -1019,18 +847,10 @@ class BartEncoder(BartPreTrainedModel): hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - if self._use_flash_attention_2: - attention_mask = attention_mask if 0 in attention_mask else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1116,8 +936,6 @@ class BartDecoder(BartPreTrainedModel): config.d_model, ) self.layers = nn.ModuleList([BartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -1232,12 +1050,18 @@ class BartDecoder(BartPreTrainedModel): # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) # initialize `past_key_values` return_legacy_cache = False @@ -1267,38 +1091,25 @@ class BartDecoder(BartPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) - causal_mask = self._update_causal_mask( + + attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, ) - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=seq_length, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) - # embed positions - position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position) - position_ids = position_ids.to(inputs_embeds.device) + positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + positions = positions.to(inputs_embeds.device) - hidden_states = inputs_embeds + position_ids + hidden_states = inputs_embeds + positions hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1331,7 +1142,7 @@ class BartDecoder(BartPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1344,7 +1155,7 @@ class BartDecoder(BartPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 4ff34b9ef2..d49d4e65bd 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -29,7 +29,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,20 +41,14 @@ from ...modeling_outputs import ( Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( - auto_docstring, - is_torch_flex_attn_available, - is_torchdynamo_compiling, - logging, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging from .configuration_bigbird_pegasus import BigBirdPegasusConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -1179,6 +1175,37 @@ class BigBirdPegasusEncoderAttention(nn.Module): return outputs +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder class BigBirdPegasusDecoderAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1231,17 +1258,25 @@ class BigBirdPegasusDecoderAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -1262,8 +1297,8 @@ class BigBirdPegasusDecoderAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -1275,66 +1310,27 @@ class BigBirdPegasusDecoderAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class BigBirdPegasusEncoderLayer(nn.Module): @@ -1434,6 +1430,7 @@ class BigBirdPegasusDecoderLayer(nn.Module): dropout=config.attention_dropout, is_decoder=True, bias=config.use_bias, + config=config, layer_idx=layer_idx, ) self.dropout = config.dropout @@ -1447,6 +1444,7 @@ class BigBirdPegasusDecoderLayer(nn.Module): dropout=config.attention_dropout, is_decoder=True, bias=config.use_bias, + config=config, layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -1510,7 +1508,6 @@ class BigBirdPegasusDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -1602,23 +1599,32 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -1627,7 +1633,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1661,7 +1667,6 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -1727,6 +1732,42 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): """ @@ -2172,9 +2213,13 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -2207,28 +2252,26 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) - causal_mask = self._update_causal_mask( + + attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions - position_ids = cache_position.unsqueeze(0) - position_ids = self.embed_positions( - (batch_size, seq_length), past_key_values_length, position_ids=position_ids - ) - position_ids = position_ids.to(inputs_embeds.device) - hidden_states = inputs_embeds + position_ids + positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + positions = positions.to(inputs_embeds.device) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # decoder layers @@ -2258,7 +2301,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -2271,7 +2314,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -2979,7 +3022,7 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index d93b6f6ae2..0b2a0dc274 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_biogpt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved. # @@ -12,56 +18,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch BioGPT model.""" import math -from typing import Optional, Tuple, Union +from functools import partial +from typing import Callable, Optional, Tuple, Union import torch -import torch.utils.checkpoint -from torch import nn +import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, -) +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( - auto_docstring, - is_torch_flex_attn_available, - is_torchdynamo_compiling, - logging, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, is_torch_flex_attn_available, logging from .configuration_biogpt import BioGptConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) -# copied from transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding with OPT->BioGpt -# TODO @ArthurZucker bring copied from back class BioGptLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. """ def __init__(self, num_embeddings: int, embedding_dim: int): - # BioGpt is set up so that if padding_idx is specified then offset the embedding ids by 2 + # BIOGPT is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. Other models don't have this hack self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) @@ -70,22 +66,19 @@ class BioGptLearnedPositionalEmbedding(nn.Embedding): self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, - position_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, ): """`input_ids_shape` is expected to be [bsz x seqlen].""" + if position_ids is None: - attention_mask = attention_mask.long() - - # create positions depending on attention_mask - positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 - + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() # cut positions if `past_key_values_length` is > 0 - position_ids = positions[:, past_key_values_length:] + position_ids = position_ids[:, past_key_values_length:] return super().forward(position_ids + self.offset) -# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt class BioGptScaledWordEmbedding(nn.Embedding): """ This module overrides nn.Embeddings' forward by multiplying with embeddings scale. @@ -99,7 +92,36 @@ class BioGptScaledWordEmbedding(nn.Embedding): return super().forward(input_ids) * self.embed_scale -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BioGpt +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class BioGptAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -151,17 +173,25 @@ class BioGptAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -182,8 +212,8 @@ class BioGptAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -195,178 +225,27 @@ class BioGptAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->BioGpt -class BioGptSdpaAttention(BioGptAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "BioGptModel is using BioGptSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - causal_mask = None - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value - - -BIOGPT_ATTENTION_CLASSES = { - "eager": BioGptAttention, - "sdpa": BioGptSdpaAttention, -} + return attn_output, attn_weights, past_key_value class BioGptDecoderLayer(nn.Module): @@ -374,12 +253,13 @@ class BioGptDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = BIOGPT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BioGptAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.attention_probs_dropout_prob, is_decoder=True, is_causal=True, + config=config, layer_idx=layer_idx, ) self.dropout = config.hidden_dropout_prob @@ -400,7 +280,9 @@ class BioGptDecoderLayer(nn.Module): past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, cache_position: Optional[torch.Tensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -431,7 +313,9 @@ class BioGptDecoderLayer(nn.Module): attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + position_ids=position_ids, cache_position=cache_position, + **flash_attn_kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -462,7 +346,10 @@ class BioGptPreTrainedModel(PreTrainedModel): config_class = BioGptConfig base_model_prefix = "biogpt" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True @@ -482,23 +369,32 @@ class BioGptPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -507,7 +403,7 @@ class BioGptPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -541,7 +437,6 @@ class BioGptPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -628,7 +523,6 @@ class BioGptModel(BioGptPreTrainedModel): self.layer_norm = nn.LayerNorm(self.embed_dim) self.gradient_checkpointing = False - self._use_sdpa = config._attn_implementation == "sdpa" # Initialize weights and apply final processing self.post_init() @@ -652,7 +546,7 @@ class BioGptModel(BioGptPreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, - **kwargs, # NOOP kwargs, for now + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -663,18 +557,24 @@ class BioGptModel(BioGptPreTrainedModel): # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." ) use_cache = False @@ -696,7 +596,7 @@ class BioGptModel(BioGptPreTrainedModel): past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - if attention_mask is None and not is_torchdynamo_compiling(): + if attention_mask is None: # required mask seq length can be calculated via length of past cache mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) @@ -706,27 +606,37 @@ class BioGptModel(BioGptPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) # embed positions if position_ids is None: - position_ids = cache_position.unsqueeze(0) + # position_ids = cache_position.unsqueeze(0) + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_seen_tokens` is > 0 + position_ids = position_ids[:, past_key_values_length:] - position_ids = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) - - hidden_states = inputs_embeds + position_ids + positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) + hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = None - next_decoder_cache = None + next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -739,13 +649,14 @@ class BioGptModel(BioGptPreTrainedModel): if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, head_mask[idx] if head_mask is not None else None, None, output_attentions, use_cache, + position_ids, cache_position, ) else: @@ -756,7 +667,9 @@ class BioGptModel(BioGptPreTrainedModel): past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + position_ids=position_ids, cache_position=cache_position, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] @@ -792,6 +705,9 @@ class BioGptModel(BioGptPreTrainedModel): ) +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + @auto_docstring( custom_intro=""" BioGPT Model with a `language modeling` head on top for CLM fine-tuning. @@ -830,7 +746,7 @@ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.Tensor] = None, - **kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -852,6 +768,7 @@ class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) sequence_output = outputs[0] @@ -916,9 +833,11 @@ class BioGptForTokenClassification(BioGptPreTrainedModel): inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -935,9 +854,11 @@ class BioGptForTokenClassification(BioGptPreTrainedModel): head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, + position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] @@ -1004,9 +925,11 @@ class BioGptForSequenceClassification(BioGptPreTrainedModel): inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1023,9 +946,11 @@ class BioGptForSequenceClassification(BioGptPreTrainedModel): head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, + position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py new file mode 100644 index 0000000000..4bd675be92 --- /dev/null +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -0,0 +1,851 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BioGPT model.""" + +import math +from functools import partial +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + auto_docstring, + is_torch_flex_attn_available, + logger, +) +from ..bart.modeling_bart import ( + BartAttention, + BartDecoderLayer, + BartScaledWordEmbedding, +) +from ..opt.modeling_opt import OPTLearnedPositionalEmbedding +from .configuration_biogpt import BioGptConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding): + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + super().forward(attention_mask, past_key_values_length, position_ids) + + +class BioGptScaledWordEmbedding(BartScaledWordEmbedding): + pass + + +class BioGptAttention(BartAttention): + pass + + +class BioGptDecoderLayer(BartDecoderLayer): + def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None): + super().__init__(config) + self.embed_dim = config.hidden_size + + self.self_attn = BioGptAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_probs_dropout_prob, + is_decoder=True, + is_causal=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.hidden_dropout_prob + self.activation_fn = ACT2FN[config.hidden_act] + + self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim) + + del self.encoder_attn + del self.encoder_attn_layer_norm + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + position_ids=position_ids, + cache_position=cache_position, + **flash_attn_kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + +@auto_docstring +class BioGptPreTrainedModel(PreTrainedModel): + config_class = BioGptConfig + base_model_prefix = "biogpt" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + # Compile issues + _supports_flex_attn = False + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@auto_docstring +class BioGptModel(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.config = config + self.layerdrop = config.layerdrop + self.dropout = config.hidden_dropout_prob + self.embed_dim = config.hidden_size + self.padding_idx = config.pad_token_id + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + self.embed_tokens = BioGptScaledWordEmbedding( + config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale + ) + self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim) + + self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.layer_norm = nn.LayerNorm(self.embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize past_key_values + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None: + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + ) + + # embed positions + if position_ids is None: + # position_ids = cache_position.unsqueeze(0) + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_seen_tokens` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + position_ids, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.layer_norm(hidden_states) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +@auto_docstring( + custom_intro=""" + BioGPT Model with a `language modeling` head on top for CLM fine-tuning. + """ +) +class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["output_projection.weight"] + + def __init__(self, config): + super().__init__(config) + + self.biogpt = BioGptModel(config) + self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.output_projection + + def set_output_embeddings(self, new_embeddings): + self.output_projection = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.biogpt( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + sequence_output = outputs[0] + prediction_scores = self.output_projection(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@auto_docstring +class BioGptForTokenClassification(BioGptPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.biogpt = BioGptModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + else: + classifier_dropout = config.hidden_dropout_prob + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The BioGpt Model transformer with a sequence classification head on top (linear layer). + + [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it is required to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class BioGptForSequenceClassification(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.biogpt = BioGptModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_length = -1 + else: + if input_ids is not None: + sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_length = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.biogpt.embed_tokens + + def set_input_embeddings(self, value): + self.biogpt.embed_tokens = value + + +__all__ = [ + "BioGptForCausalLM", + "BioGptForTokenClassification", + "BioGptForSequenceClassification", + "BioGptModel", + "BioGptPreTrainedModel", +] diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 8eb282ac6f..da7282d388 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -18,7 +18,7 @@ import copy import math import os import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -31,7 +31,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,7 +41,8 @@ from ...modeling_outputs import ( Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -51,9 +54,7 @@ from .configuration_blenderbot import BlenderbotConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -110,6 +111,37 @@ class BlenderbotScaledWordEmbedding(nn.Embedding): return super().forward(input_ids) * self.embed_scale +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot class BlenderbotAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -162,17 +194,25 @@ class BlenderbotAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -193,8 +233,8 @@ class BlenderbotAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -206,69 +246,27 @@ class BlenderbotAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value - - -BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention} + return attn_output, attn_weights, past_key_value # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT @@ -277,7 +275,7 @@ class BlenderbotEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -346,7 +344,7 @@ class BlenderbotDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -360,7 +358,7 @@ class BlenderbotDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = BlenderbotAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -428,7 +426,6 @@ class BlenderbotDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -465,6 +462,10 @@ class BlenderbotPreTrainedModel(PreTrainedModel): config_class = BlenderbotConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True @@ -493,23 +494,55 @@ class BlenderbotPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -518,7 +551,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -552,7 +585,6 @@ class BlenderbotPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -618,6 +650,42 @@ class BlenderbotPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class BlenderbotEncoder(BlenderbotPreTrainedModel): """ @@ -730,10 +798,10 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -927,22 +995,28 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - ## retrieve input_ids and inputs_embeds + # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False # initialize `past_key_values` return_legacy_cache = False @@ -972,20 +1046,19 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions position_ids = self.embed_positions( diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 2f778d7293..2237907aa0 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -29,7 +29,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -37,7 +39,8 @@ from ...modeling_outputs import ( Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -48,9 +51,7 @@ from .configuration_blenderbot_small import BlenderbotSmallConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -94,6 +95,37 @@ class BlenderbotSmallLearnedPositionalEmbedding(nn.Embedding): return super().forward(position_ids) +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BlenderbotSmall class BlenderbotSmallAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -146,17 +178,25 @@ class BlenderbotSmallAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -177,8 +217,8 @@ class BlenderbotSmallAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -190,66 +230,27 @@ class BlenderbotSmallAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL @@ -258,7 +259,7 @@ class BlenderbotSmallEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotSmallAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -324,19 +325,13 @@ class BlenderbotSmallEncoderLayer(nn.Module): return outputs -# TODO: Implement attention with SDPA for TimeSeriesTransformer. -BLENDERBOT_SMALL_ATTENTION_CLASSES = { - "eager": BlenderbotSmallAttention, -} - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallDecoderLayer(nn.Module): def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotSmallAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -350,7 +345,7 @@ class BlenderbotSmallDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = BlenderbotSmallAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -424,6 +419,7 @@ class BlenderbotSmallDecoderLayer(nn.Module): layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -454,6 +450,10 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): config_class = BlenderbotSmallConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True @@ -482,23 +482,55 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -507,7 +539,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -541,7 +573,6 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -607,6 +638,42 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): """ @@ -718,10 +785,10 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -909,24 +976,28 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) - inputs_embeds = inputs_embeds * self.embed_scale + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False # initialize `past_key_values` return_legacy_cache = False @@ -956,20 +1027,19 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions position_ids = self.embed_positions( diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index e1a822ea03..d9046ea6e8 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -4,9 +4,24 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_data2vec_audio.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -16,7 +31,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -25,16 +41,14 @@ from ...modeling_outputs import ( Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_peft_available, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available from .configuration_data2vec_audio import Data2VecAudioConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - -logger = logging.get_logger(__name__) +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask class Data2VecAudioConvLayer(nn.Module): @@ -167,6 +181,36 @@ class Data2VecAudioFeatureProjection(nn.Module): return hidden_states, norm_hidden_states +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Data2VecAudioAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -201,9 +245,6 @@ class Data2VecAudioAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -212,6 +253,9 @@ class Data2VecAudioAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -219,10 +263,16 @@ class Data2VecAudioAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -237,18 +287,18 @@ class Data2VecAudioAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -260,298 +310,29 @@ class Data2VecAudioAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class Data2VecAudioFlashAttention2(Data2VecAudioAttention): - """ - Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class Data2VecAudioSdpaAttention(Data2VecAudioAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Data2VecAudioModel is using Data2VecAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class Data2VecAudioFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -576,21 +357,15 @@ class Data2VecAudioFeedForward(nn.Module): return hidden_states -DATA2VEC_AUDIO_ATTENTION_CLASSES = { - "eager": Data2VecAudioAttention, - "sdpa": Data2VecAudioSdpaAttention, - "flash_attention_2": Data2VecAudioFlashAttention2, -} - - class Data2VecAudioEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = DATA2VEC_AUDIO_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = Data2VecAudioAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -627,7 +402,6 @@ class Data2VecAudioEncoder(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -644,16 +418,11 @@ class Data2VecAudioEncoder(nn.Module): # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -702,6 +471,28 @@ class Data2VecAudioEncoder(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class Data2VecAudioAdapterLayer(nn.Module): def __init__(self, config): @@ -760,6 +551,8 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 58934d2e86..73a42937bd 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Data2VecText model.""" + import math import torch @@ -124,6 +140,8 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index d920e998f9..eb366963a6 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -4,8 +4,23 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_hubert.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -15,15 +30,17 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_hubert import HubertConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -224,6 +241,36 @@ class HubertFeatureProjection(nn.Module): return hidden_states +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class HubertAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -258,9 +305,6 @@ class HubertAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -269,6 +313,9 @@ class HubertAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -276,10 +323,16 @@ class HubertAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -294,18 +347,18 @@ class HubertAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -317,298 +370,29 @@ class HubertAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class HubertFlashAttention2(HubertAttention): - """ - Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class HubertSdpaAttention(HubertAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "HubertModel is using HubertSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class HubertFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -633,21 +417,15 @@ class HubertFeedForward(nn.Module): return hidden_states -HUBERT_ATTENTION_CLASSES = { - "eager": HubertAttention, - "sdpa": HubertSdpaAttention, - "flash_attention_2": HubertFlashAttention2, -} - - class HubertEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = HubertAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -684,7 +462,6 @@ class HubertEncoder(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -701,16 +478,11 @@ class HubertEncoder(nn.Module): # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -759,6 +531,28 @@ class HubertEncoder(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class HubertAttnAdapterLayer(nn.Module): def __init__(self, config): @@ -788,11 +582,12 @@ class HubertAttnAdapterLayer(nn.Module): class HubertEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() - self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = HubertAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -841,7 +636,6 @@ class HubertEncoderStableLayerNorm(nn.Module): [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -855,19 +649,14 @@ class HubertEncoderStableLayerNorm(nn.Module): all_self_attentions = () if output_attentions else None if attention_mask is not None: - # make sure padded tokens are not attended to + # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + hidden_states[~expand_attention_mask] = 0 + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -918,6 +707,28 @@ class HubertEncoderStableLayerNorm(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + @auto_docstring class HubertPreTrainedModel(PreTrainedModel): @@ -927,6 +738,8 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index b3e3d24cc0..75000c95cb 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Hubert model.""" + from typing import Optional, Tuple, Union import torch @@ -115,6 +131,8 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 8b728a19df..330bc620bc 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/informer/modular_informer.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_informer.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. # @@ -12,19 +18,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch Informer model.""" -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -32,19 +41,20 @@ from ...modeling_outputs import ( Seq2SeqTSModelOutput, Seq2SeqTSPredictionOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import ( - auto_docstring, - logging, -) +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_informer import InformerConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesFeatureEmbedder with TimeSeries->Informer class InformerFeatureEmbedder(nn.Module): """ Embed a sequence of categorical features. @@ -79,7 +89,6 @@ class InformerFeatureEmbedder(nn.Module): ) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerStdScaler(nn.Module): """ Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by @@ -115,7 +124,6 @@ class InformerStdScaler(nn.Module): return (data - loc) / scale, loc, scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerMeanScaler(nn.Module): """ Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data @@ -170,7 +178,6 @@ class InformerMeanScaler(nn.Module): return scaled_data, torch.zeros_like(scale), scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerNOPScaler(nn.Module): """ Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. @@ -198,40 +205,6 @@ class InformerNOPScaler(nn.Module): return data, loc, scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average -def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: - """ - Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, - meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. - - Args: - input_tensor (`torch.FloatTensor`): - Input tensor, of which the average must be computed. - weights (`torch.FloatTensor`, *optional*): - Weights tensor, of the same shape as `input_tensor`. - dim (`int`, *optional*): - The dim along which to average `input_tensor`. - - Returns: - `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. - """ - if weights is not None: - weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) - sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) - return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights - else: - return input_tensor.mean(dim=dim) - - -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll -def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: - """ - Computes the negative log likelihood loss from input distribution with respect to target. - """ - return -input.log_prob(target) - - -# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Informer class InformerSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" @@ -266,7 +239,6 @@ class InformerSinusoidalPositionalEmbedding(nn.Embedding): return super().forward(position_ids) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesValueEmbedding with TimeSeries->Info class InformerValueEmbedding(nn.Module): def __init__(self, feature_size, d_model): super().__init__() @@ -276,7 +248,156 @@ class InformerValueEmbedding(nn.Module): return self.value_projection(x) -# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Informer +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config_class = InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class InformerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -289,6 +410,7 @@ class InformerAttention(nn.Module): bias: bool = True, is_causal: bool = False, config: Optional[InformerConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -305,23 +427,31 @@ class InformerAttention(nn.Module): self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -329,110 +459,69 @@ class InformerAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class InformerProbSparseAttention(nn.Module): @@ -448,6 +537,7 @@ class InformerProbSparseAttention(nn.Module): is_decoder: bool = False, sampling_factor: int = 5, bias: bool = True, + layer_idx: Optional[int] = None, ): super().__init__() self.factor = sampling_factor @@ -463,6 +553,7 @@ class InformerProbSparseAttention(nn.Module): ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -480,6 +571,7 @@ class InformerProbSparseAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -488,45 +580,43 @@ class InformerProbSparseAttention(nn.Module): is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) @@ -681,6 +771,14 @@ class InformerEncoderLayer(nn.Module): def __init__(self, config: InformerConfig): super().__init__() self.embed_dim = config.d_model + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + if config.attention_type == "prob": self.self_attn = InformerProbSparseAttention( embed_dim=self.embed_dim, @@ -693,14 +791,8 @@ class InformerEncoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, @@ -754,25 +846,9 @@ class InformerEncoderLayer(nn.Module): class InformerDecoderLayer(nn.Module): - def __init__(self, config: InformerConfig): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model - - if config.attention_type == "prob": - self.self_attn = InformerProbSparseAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - sampling_factor=config.sampling_factor, - is_decoder=True, - ) - else: - self.self_attn = InformerAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout @@ -783,12 +859,33 @@ class InformerDecoderLayer(nn.Module): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + is_decoder=True, + layer_idx=layer_idx, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + def forward( self, hidden_states: torch.Tensor, @@ -797,9 +894,10 @@ class InformerDecoderLayer(nn.Module): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -818,47 +916,43 @@ class InformerDecoderLayer(nn.Module): output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -874,36 +968,15 @@ class InformerDecoderLayer(nn.Module): outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs -@auto_docstring -class InformerPreTrainedModel(PreTrainedModel): - config_class = InformerConfig - base_model_prefix = "model" - main_input_name = "past_values" - supports_gradient_checkpointing = True - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, InformerSinusoidalPositionalEmbedding): - module._init_weight() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - class InformerEncoder(InformerPreTrainedModel): """ - Informer encoder consisting of *config.encoder_layers* self attention layers with distillation layers. Each - attention layer is an [`InformerEncoderLayer`]. + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`InformerEncoderLayer`]. Args: config: InformerConfig @@ -914,7 +987,6 @@ class InformerEncoder(InformerPreTrainedModel): self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop - self.gradient_checkpointing = False if config.prediction_length is None: raise ValueError("The `prediction_length` config needs to be specified.") @@ -924,6 +996,7 @@ class InformerEncoder(InformerPreTrainedModel): ) self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.gradient_checkpointing = False if config.distil: self.conv_layers = nn.ModuleList( @@ -932,7 +1005,6 @@ class InformerEncoder(InformerPreTrainedModel): self.conv_layers.append(None) else: self.conv_layers = [None] * config.encoder_layers - # Initialize weights and apply final processing self.post_init() @@ -1053,7 +1125,7 @@ class InformerEncoder(InformerPreTrainedModel): class InformerDecoder(InformerPreTrainedModel): """ - Informer decoder consisting of *config.decoder_layers* layers. Each layer is a + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`InformerDecoderLayer`] Args: @@ -1071,7 +1143,7 @@ class InformerDecoder(InformerPreTrainedModel): self.embed_positions = InformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ) - self.layers = nn.ModuleList([InformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1091,6 +1163,7 @@ class InformerDecoder(InformerPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1148,6 +1221,9 @@ class InformerDecoder(InformerPreTrainedModel): for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1157,20 +1233,35 @@ class InformerDecoder(InformerPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict input_shape = inputs_embeds.size()[:-1] - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device + ) + + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) hidden_states = self.value_embedding(inputs_embeds) embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) @@ -1188,7 +1279,7 @@ class InformerDecoder(InformerPreTrainedModel): all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1208,8 +1299,6 @@ class InformerDecoder(InformerPreTrainedModel): if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1222,6 +1311,7 @@ class InformerDecoder(InformerPreTrainedModel): None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1233,14 +1323,15 @@ class InformerDecoder(InformerPreTrainedModel): cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1253,6 +1344,9 @@ class InformerDecoder(InformerPreTrainedModel): all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1269,7 +1363,6 @@ class InformerDecoder(InformerPreTrainedModel): @auto_docstring -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerModel with TimeSeriesTransformer->Informer,TIME_SERIES_TRANSFORMER->INFORMER,time-series-transformer->informer,TimeSeries->Informer class InformerModel(InformerPreTrainedModel): def __init__(self, config: InformerConfig): super().__init__(config) @@ -1408,7 +1501,6 @@ class InformerModel(InformerPreTrainedModel): def get_decoder(self): return self.decoder - # Ignore copy @auto_docstring def forward( self, @@ -1429,6 +1521,7 @@ class InformerModel(InformerPreTrainedModel): output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1586,7 +1679,16 @@ class InformerModel(InformerPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - dec_input = transformer_inputs[:, self.config.context_length :, ...] + # Avoid empty tensors and instead create a zeroes tensor which + # will be treated the same in torch, i.e. matmul with empty == all 0s + if self.config.context_length >= transformer_inputs.shape[1]: + bsz, _, dim = transformer_inputs.shape + dec_input = torch.zeros( + size=(bsz, 1, dim), device=transformer_inputs.device, dtype=transformer_inputs.dtype + ) + else: + dec_input = transformer_inputs[:, self.config.context_length :, ...] + decoder_outputs = self.decoder( inputs_embeds=dec_input, attention_mask=decoder_attention_mask, @@ -1598,6 +1700,7 @@ class InformerModel(InformerPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1618,11 +1721,42 @@ class InformerModel(InformerPreTrainedModel): ) +def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: + """ + Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, + meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. + + Args: + input_tensor (`torch.FloatTensor`): + Input tensor, of which the average must be computed. + weights (`torch.FloatTensor`, *optional*): + Weights tensor, of the same shape as `input_tensor`. + dim (`int`, *optional*): + The dim along which to average `input_tensor`. + + Returns: + `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. + """ + if weights is not None: + weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) + sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) + return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights + else: + return input_tensor.mean(dim=dim) + + +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + @auto_docstring -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerForPrediction with TimeSeriesTransformer->Informer,TIME_SERIES_TRANSFORMER->INFORMER,time-series-transformer->informer class InformerForPrediction(InformerPreTrainedModel): def __init__(self, config: InformerConfig): super().__init__(config) + self.model = InformerModel(config) if config.distribution_output == "student_t": self.distribution_output = StudentTOutput(dim=config.input_size) @@ -1660,7 +1794,6 @@ class InformerForPrediction(InformerPreTrainedModel): sliced_params = [p[:, -trailing_n:] for p in params] return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale) - # Ignore copy @auto_docstring def forward( self, @@ -1682,6 +1815,7 @@ class InformerForPrediction(InformerPreTrainedModel): output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1853,6 +1987,7 @@ class InformerForPrediction(InformerPreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, return_dict=return_dict, + cache_position=cache_position, ) prediction_loss = None diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py new file mode 100644 index 0000000000..15bcb8d38a --- /dev/null +++ b/src/transformers/models/informer/modular_informer.py @@ -0,0 +1,997 @@ +# coding=utf-8 +# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Informer model.""" + +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from ...cache_utils import EncoderDecoderCache +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_outputs import ( + BaseModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, +) +from ..bart.modeling_bart import BartAttention +from ..time_series_transformer.modeling_time_series_transformer import ( + TimeSeriesFeatureEmbedder, + TimeSeriesMeanScaler, + TimeSeriesNOPScaler, + TimeSeriesSinusoidalPositionalEmbedding, + TimeSeriesStdScaler, + TimeSeriesTransformerDecoder, + TimeSeriesTransformerDecoderLayer, + TimeSeriesTransformerEncoder, + TimeSeriesTransformerEncoderLayer, + TimeSeriesTransformerForPrediction, + TimeSeriesTransformerModel, + TimeSeriesValueEmbedding, +) +from .configuration_informer import InformerConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + +class InformerFeatureEmbedder(TimeSeriesFeatureEmbedder): + pass + + +class InformerStdScaler(TimeSeriesStdScaler): + pass + + +class InformerMeanScaler(TimeSeriesMeanScaler): + pass + + +class InformerNOPScaler(TimeSeriesNOPScaler): + pass + + +class InformerSinusoidalPositionalEmbedding(TimeSeriesSinusoidalPositionalEmbedding): + pass + + +class InformerValueEmbedding(TimeSeriesValueEmbedding): + pass + + +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config_class = InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class InformerAttention(BartAttention): + pass + + +class InformerProbSparseAttention(nn.Module): + """Probabilistic Attention mechanism to select the "active" + queries rather than the "lazy" queries and provides a sparse Transformer thus mitigating the quadratic compute and + memory requirements of vanilla attention""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + sampling_factor: int = 5, + bias: bool = True, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.factor = sampling_factor + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.layer_idx = layer_idx + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + key_states_time_length = key_states.size(1) # L_K + log_key_states_time_length = np.ceil(np.log1p(key_states_time_length)).astype("int").item() # log_L_K + + query_states_time_length = query_states.size(1) # L_Q + log_query_states_time_length = np.ceil(np.log1p(query_states_time_length)).astype("int").item() # log_L_Q + + u_part = min(self.factor * query_states_time_length * log_key_states_time_length, key_states_time_length) + u = min(self.factor * log_query_states_time_length, query_states_time_length) + + if key_states_time_length > 0: + index_sample = torch.randint(0, key_states_time_length, (u_part,)) + k_sample = key_states[:, index_sample, :] + else: + k_sample = key_states + + queries_keys_sample = torch.bmm(query_states, k_sample.transpose(1, 2)) # Q_K_sampled + + # find the Top_k query with sparsity measurement + if u > 0: + sparsity_measurement = queries_keys_sample.max(dim=-1)[0] - torch.div( + queries_keys_sample.sum(dim=-1), key_states_time_length + ) # M + top_u_sparsity_measurement = sparsity_measurement.topk(u, sorted=False)[1] # M_top + + # calculate q_reduce: query_states[:, top_u_sparsity_measurement] + dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1) + q_reduce = query_states[dim_for_slice, top_u_sparsity_measurement] + else: + q_reduce = query_states + top_u_sparsity_measurement = None + + # Use q_reduce to calculate attention weights + attn_weights = torch.bmm(q_reduce, key_states.transpose(1, 2)) + + src_len = key_states.size(1) + if attn_weights.size() != (bsz * self.num_heads, u, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, u, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + prob_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).reshape( + bsz * self.num_heads, tgt_len, src_len + ) + + if top_u_sparsity_measurement is not None: + dim_for_slice = torch.arange(prob_mask.size(0)).unsqueeze(-1) + prob_mask = prob_mask[dim_for_slice, top_u_sparsity_measurement, :] + + attn_weights = attn_weights.view(bsz, self.num_heads, u, src_len) + prob_mask.view( + bsz, self.num_heads, u, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, u, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.bmm(attn_probs, value_states) + + # calculate context for updating the attn_output, based on: + # https://github.com/zhouhaoyi/Informer2020/blob/ac59c7447135473fb2aafeafe94395f884d5c7a5/models/attn.py#L74 + if self.is_decoder: + # cast to float32 before operation to avoid overflow + context = value_states.cumsum(dim=-2, dtype=torch.float32).to(value_states.dtype) + else: + v_mean_dim_time = value_states.mean(dim=-2) + context = ( + v_mean_dim_time.unsqueeze(dim=1) + .expand(bsz * self.num_heads, query_states_time_length, v_mean_dim_time.size(-1)) + .clone() + ) + + if top_u_sparsity_measurement is not None: + # update context: copy the attention output to the context at top_u_sparsity_measurement index + dim_for_slice = torch.arange(context.size(0)).unsqueeze(-1) + context[dim_for_slice, top_u_sparsity_measurement, :] = attn_output + attn_output = context + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py +class InformerConvLayer(nn.Module): + def __init__(self, c_in): + super().__init__() + self.downConv = nn.Conv1d( + in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=1, + padding_mode="circular", + ) + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class InformerEncoderLayer(TimeSeriesTransformerEncoderLayer): + def __init__(self, config: InformerConfig): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class InformerDecoderLayer(TimeSeriesTransformerDecoderLayer): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + is_decoder=True, + layer_idx=layer_idx, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + + +class InformerEncoder(TimeSeriesTransformerEncoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.gradient_checkpointing = False + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + if config.distil: + self.conv_layers = nn.ModuleList( + [InformerConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)] + ) + self.conv_layers.append(None) + else: + self.conv_layers = [None] * config.encoder_layers + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size()) + + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, (encoder_layer, conv_layer) in enumerate(zip(self.layers, self.conv_layers)): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + if conv_layer is not None: + output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class InformerDecoder(TimeSeriesTransformerDecoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + +class InformerModel(TimeSeriesTransformerModel, nn.Module): + def __init__(self, config: InformerConfig): + nn.Module().__init__(config) + + if config.scaling == "mean" or config.scaling is True: + self.scaler = InformerMeanScaler(config) + elif config.scaling == "std": + self.scaler = InformerStdScaler(config) + else: + self.scaler = InformerNOPScaler(config) + + if config.num_static_categorical_features > 0: + self.embedder = InformerFeatureEmbedder( + cardinalities=config.cardinality, + embedding_dims=config.embedding_dimension, + ) + + # transformer encoder-decoder and mask initializer + self.encoder = InformerEncoder(config) + self.decoder = InformerDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerModel + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly") + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + super().forward(**super_kwargs) + + +class InformerForPrediction(TimeSeriesTransformerForPrediction, nn.Module): + def __init__(self, config: InformerConfig): + nn.Module().__init__(config) + + self.model = InformerModel(config) + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.input_size) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.input_size) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.input_size) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model) + self.target_shape = self.distribution_output.event_shape + + if config.loss == "nll": + self.loss = nll + else: + raise ValueError(f"Unknown loss function {config.loss}") + + # Initialize weights of distribution_output and apply final processing + self.post_init() + + @auto_docstring + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + This mask is used to filter out missing values for the final loss calculation. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerForPrediction + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerForPrediction.from_pretrained( + ... "huggingface/informer-tourism-monthly" + ... ) + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values + >>> # as well as possible additional features + >>> # the model autoregressively generates future values + >>> outputs = model.generate( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> mean_prediction = outputs.sequences.mean(dim=1) + ```""" + super().forward(**super_kwargs) + + +__all__ = ["InformerForPrediction", "InformerModel", "InformerPreTrainedModel"] diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 757a393a0b..55ecad4152 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -15,7 +15,7 @@ """PyTorch M2M100 model.""" import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -31,31 +31,23 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( - auto_docstring, - is_torch_flex_attn_available, - is_torchdynamo_compiling, - logging, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging from .configuration_m2m_100 import M2M100Config if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -184,6 +176,37 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module): return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->M2M100 class M2M100Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -236,17 +259,25 @@ class M2M100Attention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -267,8 +298,8 @@ class M2M100Attention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -280,292 +311,27 @@ class M2M100Attention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->M2M100 -class M2M100FlashAttention2(M2M100Attention): - """ - M2M100 flash attention module. This module inherits from `M2M100Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # M2M100FlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError( - "M2M100SdpaAttention2 attention does not support `output_attentions`. " - "Use the argument `attn_implementation='eager'` when loading the model." - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->M2M100 -class M2M100SdpaAttention(M2M100Attention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "M2M100Model is using M2M100SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - causal_mask = None - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value + return attn_output, attn_weights, past_key_value # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 @@ -574,7 +340,7 @@ class M2M100EncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = M2M100Attention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -637,20 +403,13 @@ class M2M100EncoderLayer(nn.Module): return outputs -M2M100_ATTENTION_CLASSES = { - "eager": M2M100Attention, - "flash_attention_2": M2M100FlashAttention2, - "sdpa": M2M100SdpaAttention, -} - - # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 class M2M100DecoderLayer(nn.Module): def __init__(self, config: M2M100Config, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model - self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = M2M100Attention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -664,7 +423,7 @@ class M2M100DecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = M2M100Attention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -732,7 +491,6 @@ class M2M100DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -772,6 +530,8 @@ class M2M100PreTrainedModel(PreTrainedModel): _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True # Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model _supports_static_cache = False @@ -790,23 +550,55 @@ class M2M100PreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -815,7 +607,7 @@ class M2M100PreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -849,7 +641,6 @@ class M2M100PreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -915,6 +706,42 @@ class M2M100PreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class M2M100Encoder(M2M100PreTrainedModel): """ @@ -951,8 +778,6 @@ class M2M100Encoder(M2M100PreTrainedModel): ) self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -1031,18 +856,10 @@ class M2M100Encoder(M2M100PreTrainedModel): hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - if self._use_flash_attention_2: - attention_mask = attention_mask if 0 in attention_mask else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1133,8 +950,6 @@ class M2M100Decoder(M2M100PreTrainedModel): self.padding_idx, ) self.layers = nn.ModuleList([M2M100DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1232,23 +1047,27 @@ class M2M100Decoder(M2M100PreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + # initialize `past_key_values` return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): @@ -1277,35 +1096,19 @@ class M2M100Decoder(M2M100PreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) - causal_mask = self._update_causal_mask( + + attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=seq_length, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=seq_length, - ) - + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) positions = positions.to(inputs_embeds.device) @@ -1313,13 +1116,6 @@ class M2M100Decoder(M2M100PreTrainedModel): hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -1351,7 +1147,7 @@ class M2M100Decoder(M2M100PreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1364,7 +1160,7 @@ class M2M100Decoder(M2M100PreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index a9a3fd353e..016cb865f8 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -30,7 +30,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -38,7 +40,8 @@ from ...modeling_outputs import ( Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -108,6 +111,37 @@ class MarianSinusoidalPositionalEmbedding(nn.Embedding): return super().forward(position_ids) +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Marian class MarianAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -160,17 +194,25 @@ class MarianAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -191,8 +233,8 @@ class MarianAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -204,66 +246,27 @@ class MarianAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN @@ -272,7 +275,7 @@ class MarianEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MarianAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -338,16 +341,13 @@ class MarianEncoderLayer(nn.Module): return outputs -MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention} - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN class MarianDecoderLayer(nn.Module): def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model - self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MarianAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -361,7 +361,7 @@ class MarianDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = MarianAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -435,6 +435,7 @@ class MarianDecoderLayer(nn.Module): layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -465,6 +466,10 @@ class MarianPreTrainedModel(PreTrainedModel): config_class = MarianConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True @@ -484,23 +489,66 @@ class MarianPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + "decoder_input_ids": input_ids, + } + return dummy_inputs + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -509,7 +557,7 @@ class MarianPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -543,7 +591,6 @@ class MarianPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -609,16 +656,41 @@ class MarianPreTrainedModel(PreTrainedModel): return causal_mask - @property - def dummy_inputs(self): - pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) - dummy_inputs = { - "attention_mask": input_ids.ne(pad_token), - "input_ids": input_ids, - "decoder_input_ids": input_ids, - } - return dummy_inputs + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask class MarianEncoder(MarianPreTrainedModel): @@ -734,10 +806,10 @@ class MarianEncoder(MarianPreTrainedModel): hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -929,12 +1001,18 @@ class MarianDecoder(MarianPreTrainedModel): # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) # Important to apply outside of the above `if`, in case user passes `embeds` inputs_embeds = inputs_embeds * self.embed_scale @@ -967,22 +1045,19 @@ class MarianDecoder(MarianPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=seq_length, - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions position_ids = self.embed_positions( diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 3fbb3e8b5b..bdf352a1f6 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -31,7 +31,9 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -41,7 +43,8 @@ from ...modeling_outputs import ( Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -52,13 +55,7 @@ from .configuration_mbart import MBartConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask - - -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -124,6 +121,37 @@ class MBartScaledWordEmbedding(nn.Embedding): return super().forward(input_ids) * self.embed_scale +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart class MBartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -176,17 +204,25 @@ class MBartAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -207,8 +243,8 @@ class MBartAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -220,299 +256,27 @@ class MBartAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart -class MBartFlashAttention2(MBartAttention): - """ - MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # MBartFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError( - "MBartSdpaAttention2 attention does not support `output_attentions`. " - "Use the argument `attn_implementation='eager'` when loading the model." - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart -class MBartSdpaAttention(MBartAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MBartModel is using MBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - cache_position=cache_position, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - causal_mask = None - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -MBART_ATTENTION_CLASSES = { - "eager": MBartAttention, - "sdpa": MBartSdpaAttention, - "flash_attention_2": MBartFlashAttention2, -} + return attn_output, attn_weights, past_key_value class MBartEncoderLayer(nn.Module): @@ -520,7 +284,7 @@ class MBartEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MBartAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -588,7 +352,7 @@ class MBartDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MBartAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -602,7 +366,7 @@ class MBartDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = MBartAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -670,7 +434,6 @@ class MBartDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -735,6 +498,8 @@ class MBartPreTrainedModel(PreTrainedModel): _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"] _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True @@ -762,23 +527,55 @@ class MBartPreTrainedModel(PreTrainedModel): } return dummy_inputs - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -787,7 +584,7 @@ class MBartPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -821,7 +618,6 @@ class MBartPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -887,6 +683,42 @@ class MBartPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class MBartEncoder(MBartPreTrainedModel): """ @@ -1007,18 +839,10 @@ class MBartEncoder(MBartPreTrainedModel): hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1214,12 +1038,18 @@ class MBartDecoder(MBartPreTrainedModel): # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) if self.gradient_checkpointing and self.training: if use_cache: @@ -1256,32 +1086,19 @@ class MBartDecoder(MBartPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=seq_length, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 00a0afd62a..a0e21f586c 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -19,7 +19,7 @@ import inspect import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -40,7 +40,9 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -48,15 +50,17 @@ from ...modeling_outputs import ( ModelOutput, Seq2SeqLMOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ..auto.configuration_auto import AutoConfig from ..auto.modeling_auto import AutoModel from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + if TYPE_CHECKING: from ...generation.streamers import BaseStreamer @@ -146,7 +150,38 @@ class MusicgenSinusoidalPositionalEmbedding(nn.Module): return self.weights.index_select(0, position_ids.view(-1)).detach() -# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Musicgen +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Musicgen class MusicgenAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -181,9 +216,6 @@ class MusicgenAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -192,6 +224,9 @@ class MusicgenAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -199,10 +234,16 @@ class MusicgenAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -217,18 +258,18 @@ class MusicgenAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -240,329 +281,35 @@ class MusicgenAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.hubert.modeling_hubert.HubertFlashAttention2 with Hubert->Musicgen -class MusicgenFlashAttention2(MusicgenAttention): - """ - Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class MusicgenSdpaAttention(MusicgenAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - if ( - attention_mask is not None - and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any() - ): - logger.warning_once( - '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - "Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information." - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -MUSICGEN_ATTENTION_CLASSES = { - "eager": MusicgenAttention, - "sdpa": MusicgenSdpaAttention, - "flash_attention_2": MusicgenFlashAttention2, -} - - class MusicgenDecoderLayer(nn.Module): def __init__(self, config: MusicgenDecoderConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MusicgenAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.attention_dropout, @@ -576,7 +323,7 @@ class MusicgenDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = MusicgenAttention( self.embed_dim, config.num_attention_heads, dropout=config.attention_dropout, @@ -590,6 +337,7 @@ class MusicgenDecoderLayer(nn.Module): self.final_layer_norm = nn.LayerNorm(self.embed_dim) # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + # TODO: change to new cache class def forward( self, hidden_states: torch.Tensor, @@ -688,6 +436,8 @@ class MusicgenPreTrainedModel(PreTrainedModel): _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] _supports_flash_attn_2 = True _supports_sdpa = True + # compilation errors occurr atm + _supports_flex_attn = False def _init_weights(self, module): std = self.config.initializer_factor @@ -819,40 +569,18 @@ class MusicgenDecoder(MusicgenPreTrainedModel): if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) - if self.attn_implementation == "flash_attention_2": - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.attn_implementation == "flash_attention_2": - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions positions = self.embed_positions(input, past_key_values_length) @@ -951,6 +679,80 @@ class MusicgenDecoder(MusicgenPreTrainedModel): cross_attentions=all_cross_attentions, ) + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class MusicgenModel(MusicgenPreTrainedModel): @@ -1559,6 +1361,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # compilation errors occurr atm + _supports_flex_attn = False def __init__( self, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index c5328d0647..3312ad33cd 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -19,7 +19,7 @@ import inspect import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -34,18 +34,25 @@ from ...generation import ( LogitsProcessorList, StoppingCriteriaList, ) -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, +) from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ..auto.configuration_auto import AutoConfig from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + if TYPE_CHECKING: from ...generation.streamers import BaseStreamer @@ -159,7 +166,38 @@ class MusicgenMelodySinusoidalPositionalEmbedding(nn.Module): return self.weights.index_select(0, position_ids.view(-1)).detach() -# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->MusicgenMelody +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->MusicgenMelody class MusicgenMelodyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -194,9 +232,6 @@ class MusicgenMelodyAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -205,6 +240,9 @@ class MusicgenMelodyAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -212,10 +250,16 @@ class MusicgenMelodyAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -230,18 +274,18 @@ class MusicgenMelodyAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -253,313 +297,35 @@ class MusicgenMelodyAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.hubert.modeling_hubert.HubertFlashAttention2 with Hubert->MusicgenMelody -class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention): - """ - MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -# Copied from transformers.models.hubert.modeling_hubert.HubertSdpaAttention with Hubert->MusicgenMelody -class MusicgenMelodySdpaAttention(MusicgenMelodyAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MusicgenMelodyModel is using MusicgenMelodySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -MUSICGEN_MELODY_ATTENTION_CLASSES = { - "eager": MusicgenMelodyAttention, - "sdpa": MusicgenMelodySdpaAttention, - "flash_attention_2": MusicgenMelodyFlashAttention2, -} - - class MusicgenMelodyDecoderLayer(nn.Module): def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = MUSICGEN_MELODY_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MusicgenMelodyAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.attention_dropout, @@ -644,6 +410,8 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] _supports_flash_attn_2 = True _supports_sdpa = True + # compilation errors occurr atm + _supports_flex_attn = False def _init_weights(self, module): std = self.config.initializer_factor @@ -785,21 +553,12 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel): input_shape = inputs_embeds.size()[:-1] - if self.attn_implementation == "flash_attention_2": - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) # embed positions positions = self.embed_positions(inputs_embeds, past_key_values_length) @@ -881,6 +640,57 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel): attentions=all_attentions, ) + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Ignore copy + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # MusicgenMelody doesn't apply cross attention, hence it's ignored here + # and only exists to not confuse any copy checks + pass + @auto_docstring # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody @@ -1482,6 +1292,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # compilation errors occurr atm + _supports_flex_attn = False def __init__( self, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 3729fd6e18..6d7bd6c985 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -15,7 +15,7 @@ """PyTorch NLLB-MoE model.""" import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -25,18 +25,29 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, Seq2SeqMoEModelOutput, Seq2SeqMoEOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_nllb_moe import NllbMoeConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -460,7 +471,38 @@ class NllbMoeSparseMLP(nn.Module): return hidden_states, (router_probs, top_1_expert_index) -# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->NllbMoe,key_value_states->encoder_hidden_states +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states class NllbMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -495,9 +537,6 @@ class NllbMoeAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -506,6 +545,9 @@ class NllbMoeAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -513,10 +555,16 @@ class NllbMoeAttention(nn.Module): # for the decoder is_cross_attention = encoder_hidden_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = encoder_hidden_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -531,18 +579,18 @@ class NllbMoeAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz) + key_states = self.k_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -554,69 +602,27 @@ class NllbMoeAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class NllbMoeEncoderLayer(nn.Module): @@ -628,6 +634,7 @@ class NllbMoeEncoderLayer(nn.Module): embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.attn_dropout = nn.Dropout(config.dropout) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -710,6 +717,7 @@ class NllbMoeDecoderLayer(nn.Module): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -717,7 +725,11 @@ class NllbMoeDecoderLayer(nn.Module): self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.cross_attention = NllbMoeAttention( - self.embed_dim, config.decoder_attention_heads, config.attention_dropout, is_decoder=True + self.embed_dim, + config.decoder_attention_heads, + config.attention_dropout, + is_decoder=True, + config=config, ) self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) if not self.is_sparse: @@ -837,6 +849,12 @@ class NllbMoePreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["NllbMoeEncoderLayer", "NllbMoeDecoderLayer"] + # TODO: If anyone is up to it to make sure tests pass etc + # Flash attention has problems due to not preparing masks the same way as eager/sdpa + # SDPA has more flaky logits which requires more time to look into tests + _supports_flash_attn_2 = False + _supports_sdpa = False + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" @@ -975,10 +993,10 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_router_probs = () if output_router_logits else None @@ -1042,6 +1060,29 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): router_probs=all_router_probs, ) + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class NllbMoeDecoder(NllbMoePreTrainedModel): """ @@ -1195,18 +1236,18 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) # embed positions positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) @@ -1264,7 +1305,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.forward, hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, @@ -1276,7 +1317,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=layer_head_mask, @@ -1330,6 +1371,82 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): router_probs=all_router_probs, ) + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class NllbMoeModel(NllbMoePreTrainedModel): diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index c10a55b1b0..8f00e89009 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -24,9 +24,11 @@ import torch.nn as nn from transformers.modeling_utils import PreTrainedModel from transformers.utils import ModelOutput +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, logging -from ...utils.deprecation import deprecate_kwarg from .configuration_patchtsmixer import PatchTSMixerConfig @@ -235,7 +237,38 @@ class PatchTSMixerChannelFeatureMixerBlock(nn.Module): return out -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTSMixer +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTSMixer class PatchTSMixerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -248,7 +281,6 @@ class PatchTSMixerAttention(nn.Module): bias: bool = True, is_causal: bool = False, config: Optional[PatchTSMixerConfig] = None, - layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -265,23 +297,12 @@ class PatchTSMixerAttention(nn.Module): self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal - self.layer_idx = layer_idx - if layer_idx is None and self.is_decoder: - logger.warning_once( - f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " - "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Ignore copy - @deprecate_kwarg("key_value_states", version="4.55") - @deprecate_kwarg("past_key_value", version="4.55") - @deprecate_kwarg("cache_position", version="4.55") def forward( self, hidden_states: torch.Tensor, @@ -290,79 +311,84 @@ class PatchTSMixerAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - bsz, tgt_len, _ = hidden_states.size() + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) else: - attn_weights_reshaped = None + # self_attention + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) - attn_output = torch.bmm(attn_probs, value_states) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, None + return attn_output, attn_weights, past_key_value class PatchMixerBlock(nn.Module): @@ -395,6 +421,7 @@ class PatchMixerBlock(nn.Module): embed_dim=config.d_model, num_heads=config.self_attn_heads, dropout=config.dropout, + config=config, ) self.norm_attn = PatchTSMixerNormLayer(config) diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 57b69b6b5e..b85e8a66b2 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -16,14 +16,16 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn from ...activations import ACT2CLS +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import ModelOutput, auto_docstring, logging from .configuration_patchtst import PatchTSTConfig @@ -32,7 +34,38 @@ from .configuration_patchtst import PatchTSTConfig logger = logging.get_logger(__name__) -# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->PatchTST +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTST class PatchTSTAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -67,9 +100,6 @@ class PatchTSTAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -78,6 +108,9 @@ class PatchTSTAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -85,10 +118,16 @@ class PatchTSTAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -103,18 +142,18 @@ class PatchTSTAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -126,69 +165,27 @@ class PatchTSTAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class PatchTSTBatchNorm(nn.Module): @@ -461,6 +458,7 @@ class PatchTSTEncoderLayer(nn.Module): embed_dim=config.d_model, num_heads=config.num_attention_heads, dropout=config.attention_dropout, + config=config, ) # Add & Norm of the sublayer 1 diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 3f59a8c918..19166bd609 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -30,7 +30,9 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -38,7 +40,8 @@ from ...modeling_outputs import ( Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -49,9 +52,7 @@ from .configuration_pegasus import PegasusConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -109,6 +110,37 @@ class PegasusSinusoidalPositionalEmbedding(nn.Embedding): return super().forward(position_ids) +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus class PegasusAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -161,17 +193,25 @@ class PegasusAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -192,8 +232,8 @@ class PegasusAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -205,69 +245,27 @@ class PegasusAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value - - -PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention} + return attn_output, attn_weights, past_key_value # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS @@ -276,7 +274,7 @@ class PegasusEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = PegasusAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -345,7 +343,7 @@ class PegasusDecoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = PegasusAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -359,7 +357,7 @@ class PegasusDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = PegasusAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -427,7 +425,6 @@ class PegasusDecoderLayer(nn.Module): residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, @@ -464,6 +461,12 @@ class PegasusPreTrainedModel(PreTrainedModel): config_class = PegasusConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + # Compile issues + _supports_flex_attn = False + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -481,23 +484,55 @@ class PegasusPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -506,7 +541,7 @@ class PegasusPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -540,7 +575,6 @@ class PegasusPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -606,6 +640,42 @@ class PegasusPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class PegasusEncoder(PegasusPreTrainedModel): """ @@ -748,10 +818,10 @@ class PegasusEncoder(PegasusPreTrainedModel): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -970,26 +1040,32 @@ class PegasusDecoder(PegasusPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) # important to apply scale outside of `if` in case users pass `embeds` inputs_embeds = inputs_embeds * self.embed_scale + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + # initialize `past_key_values` return_legacy_cache = False if use_cache and not isinstance(past_key_values, Cache): @@ -1018,20 +1094,19 @@ class PegasusDecoder(PegasusPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions positions = self.embed_positions((batch_size, seq_length), past_key_values_length, position_ids=cache_position) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 04cf37a762..a2fcf5edd1 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -16,7 +16,7 @@ import dataclasses import math -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -30,14 +30,17 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, is_torch_flex_attn_available, @@ -48,9 +51,7 @@ from .configuration_pegasus_x import PegasusXConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -135,6 +136,37 @@ class PegasusXSinusoidalPositionalEmbedding(nn.Module): return pe[None].expand(batch_size, -1, -1) +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PegasusX class PegasusXAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -187,17 +219,25 @@ class PegasusXAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -218,8 +258,8 @@ class PegasusXAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -231,66 +271,27 @@ class PegasusXAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class PegasusXGlobalLocalAttention(nn.Module): @@ -653,6 +654,7 @@ class PegasusXDecoderLayer(nn.Module): dropout=config.attention_dropout, is_decoder=True, bias=False, + config=config, layer_idx=layer_idx, ) self.dropout = config.dropout @@ -666,6 +668,7 @@ class PegasusXDecoderLayer(nn.Module): dropout=config.attention_dropout, is_decoder=True, bias=False, + config=config, layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -758,6 +761,11 @@ class PegasusXPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"] + _supports_flash_attn_2 = True + # Flaky logits + _supports_sdpa = False + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True @@ -773,23 +781,55 @@ class PegasusXPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -798,7 +838,7 @@ class PegasusXPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -832,7 +872,6 @@ class PegasusXPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -898,6 +937,42 @@ class PegasusXPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class PegasusXEncoder(PegasusXPreTrainedModel): """ @@ -1227,22 +1302,28 @@ class PegasusXDecoder(PegasusXPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False # initialize `past_key_values` return_legacy_cache = False @@ -1272,20 +1353,19 @@ class PegasusXDecoder(PegasusXPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + ) # embed positions position_ids = cache_position.unsqueeze(1) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 614baee8bf..85813b0242 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/plbart/modular_plbart.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_plbart.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. # @@ -12,14 +18,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch PLBART model.""" import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -31,6 +35,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,73 +44,19 @@ from ...modeling_outputs import ( Seq2SeqModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( - auto_docstring, - is_torch_flex_attn_available, - is_torchdynamo_compiling, - logging, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging from .configuration_plbart import PLBartConfig if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) -# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right -def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): - """ - Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not - have a single `decoder_start_token_id` in contrast to other Bart-like models. - """ - prev_output_tokens = input_ids.clone() - - if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) - - index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) - decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() - prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() - prev_output_tokens[:, 0] = decoder_start_tokens - - return prev_output_tokens - - -# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart -class PLBartLearnedPositionalEmbedding(nn.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int): - # PLBart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): - """`input_ids' shape is expected to be [bsz x seqlen].""" - - if position_ids is None: - bsz, seq_len = input_ids.shape[:2] - position_ids = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) - else: - position_ids = position_ids.unsqueeze(0) - - return super().forward(position_ids + self.offset) - - -# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart class PLBartScaledWordEmbedding(nn.Embedding): """ This module overrides nn.Embeddings' forward by multiplying with embeddings scale. @@ -119,392 +70,16 @@ class PLBartScaledWordEmbedding(nn.Embedding): return super().forward(input_ids) * self.embed_scale -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart -class PLBartAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, - config: Optional[PLBartConfig] = None, - layer_idx: Optional[int] = None, - ): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - self.config = config - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - self.is_causal = is_causal - self.layer_idx = layer_idx - if layer_idx is None and self.is_decoder: - logger.warning_once( - f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " - "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling - - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) - if is_cross_attention: - # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache - else: - curr_past_key_value = past_key_value.self_attention_cache - else: - curr_past_key_value = past_key_value - - current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: - # reuse k,v, cross_attentions - key_states = curr_past_key_value.key_cache[self.layer_idx] - value_states = curr_past_key_value.value_cache[self.layer_idx] - else: - key_states = self.k_proj(current_states) - value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_position = cache_position if not is_cross_attention else None - key_states, value_states = curr_past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls - if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart, BART->PLBART -class PLBartEncoderLayer(nn.Module): - def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - config=config, - layer_idx=layer_idx, - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: torch.FloatTensor, - layer_head_mask: torch.FloatTensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - hidden_states, attn_weights, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - if hidden_states.dtype == torch.float16 and ( - torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() - ): - clamp_value = torch.finfo(hidden_states.dtype).max - 1000 - hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -# TODO: Implement attention with SDPA for PLBart. -PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention} - - -# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART -class PLBartDecoderLayer(nn.Module): - def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - is_causal=True, - config=config, - layer_idx=layer_idx, - ) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - config=config, - layer_idx=layer_idx, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = True, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - encoder_hidden_states (`torch.FloatTensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of - size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. It is used to update the - cache in the correct position and to infer the complete sequence length. - """ - residual = hidden_states - - # Self Attention - hidden_states, self_attn_weights, past_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - cache_position=cache_position, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (past_key_value,) - - return outputs - - -# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->PLBart -class PLBartClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__( - self, - input_dim: int, - inner_dim: int, - num_classes: int, - pooler_dropout: float, - ): - super().__init__() - self.dense = nn.Linear(input_dim, inner_dim) - self.dropout = nn.Dropout(p=pooler_dropout) - self.out_proj = nn.Linear(inner_dim, num_classes) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dropout(hidden_states) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - @auto_docstring class PLBartPreTrainedModel(PreTrainedModel): config_class = PLBartConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] - _supports_cache_class = True - _supports_static_cache = True + _supports_flash_attn_2 = True + _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): std = self.config.init_std @@ -516,27 +91,56 @@ class PLBartPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() - # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, - attention_mask: Union[torch.Tensor, "BlockMask"], + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool = False, ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None - if self.config._attn_implementation == "flex_attention": - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - return attention_mask # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail @@ -545,7 +149,7 @@ class PLBartPreTrainedModel(PreTrainedModel): using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -579,7 +183,6 @@ class PLBartPreTrainedModel(PreTrainedModel): self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -645,8 +248,295 @@ class PLBartPreTrainedModel(PreTrainedModel): return causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class PLBartLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # PLBart is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) + + return super().forward(position_ids + self.offset) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class PLBartAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[PLBartConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +class PLBartEncoderLayer(nn.Module): + def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PLBartAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + layer_idx=layer_idx, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + -# Copied from transformers.models.bart.modeling_bart.BartEncoder with Bart->PLBart class PLBartEncoder(PLBartPreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -680,8 +570,6 @@ class PLBartEncoder(PLBartPreTrainedModel): embed_dim, ) self.layers = nn.ModuleList([PLBartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -767,18 +655,10 @@ class PLBartEncoder(PLBartPreTrainedModel): hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - if self._use_flash_attention_2: - attention_mask = attention_mask if 0 in attention_mask else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -835,7 +715,125 @@ class PLBartEncoder(PLBartPreTrainedModel): ) -# Copied from transformers.models.bart.modeling_bart.BartDecoder with Bart->PLBart +class PLBartDecoderLayer(nn.Module): + def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PLBartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + layer_idx=layer_idx, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = PLBartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + layer_idx=layer_idx, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, past_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + class PLBartDecoder(PLBartPreTrainedModel): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PLBartDecoderLayer`] @@ -865,8 +863,6 @@ class PLBartDecoder(PLBartPreTrainedModel): config.d_model, ) self.layers = nn.ModuleList([PLBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -981,12 +977,18 @@ class PLBartDecoder(PLBartPreTrainedModel): # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - - if input_ids is not None: - input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) # initialize `past_key_values` return_legacy_cache = False @@ -1016,38 +1018,25 @@ class PLBartDecoder(PLBartPreTrainedModel): if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) - causal_mask = self._update_causal_mask( + + attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, ) - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=seq_length, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length - ) - # embed positions - position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position) - position_ids = position_ids.to(inputs_embeds.device) + positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + positions = positions.to(inputs_embeds.device) - hidden_states = inputs_embeds + position_ids + hidden_states = inputs_embeds + positions hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1080,7 +1069,7 @@ class PLBartDecoder(PLBartPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1093,7 +1082,7 @@ class PLBartDecoder(PLBartPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1139,6 +1128,26 @@ class PLBartDecoder(PLBartPreTrainedModel): ) +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): + """ + Shift input ids one token to the right, and wrap the last non pad token (the token) Note that PLBart does not + have a single `decoder_start_token_id` in contrast to other Bart-like models. + """ + prev_output_tokens = input_ids.clone() + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) + + index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) + decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() + prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() + prev_output_tokens[:, 0] = decoder_start_tokens + + return prev_output_tokens + + @auto_docstring class PLBartModel(PLBartPreTrainedModel): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] @@ -1192,7 +1201,7 @@ class PLBartModel(PLBartPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1350,7 +1359,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1469,10 +1478,34 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): return reordered_past +class PLBartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + @auto_docstring( custom_intro=""" - PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for code - classification. + PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. + for GLUE tasks. """ ) class PLBartForSequenceClassification(PLBartPreTrainedModel): @@ -1492,7 +1525,6 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel): self.post_init() @auto_docstring - # Ignore copy def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1621,7 +1653,6 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel): ) -# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PLBart class PLBartDecoderWrapper(PLBartPreTrainedModel): """ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is @@ -1636,7 +1667,11 @@ class PLBartDecoderWrapper(PLBartPreTrainedModel): return self.decoder(*args, **kwargs) -# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base +@auto_docstring( + custom_intro=""" + PLBART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings). + """ +) class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py new file mode 100644 index 0000000000..8a9755b11b --- /dev/null +++ b/src/transformers/models/plbart/modular_plbart.py @@ -0,0 +1,692 @@ +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PLBART model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_outputs import ( + BaseModelOutput, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring, is_torch_flex_attn_available +from ..bart.modeling_bart import ( + BartClassificationHead, + BartDecoder, + BartEncoder, + BartForCausalLM, + BartScaledWordEmbedding, +) +from ..bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusForSequenceClassification +from ..mbart.modeling_mbart import shift_tokens_right +from .configuration_plbart import PLBartConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + +class PLBartScaledWordEmbedding(BartScaledWordEmbedding): + pass + + +@auto_docstring +class PLBartPreTrainedModel(PreTrainedModel): + config_class = PLBartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + # Compile issues + _supports_flex_attn = False + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +class PLBartEncoder(BartEncoder): + pass + + +class PLBartDecoder(BartDecoder): + pass + + +@auto_docstring +class PLBartModel(PLBartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) + + self.encoder = PLBartEncoder(config, self.shared) + self.decoder = PLBartDecoder(config, self.shared) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (: + obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: + generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. + cross_attn_head_mask (: + obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify + selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # different to other models, PLBart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code. + """ +) +class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + self.model = PLBartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (: + obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: + generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. + cross_attn_head_mask (: + obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify + selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example Mask-filling: + + ```python + >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration + + >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base") + >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base") + + >>> # en_XX is the language symbol id for English + >>> TXT = " Is 0 the Fibonacci number ? en_XX" + >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids + + >>> logits = model(input_ids).logits + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) + + >>> tokenizer.decode(predictions).split() + ['first', 'same', 'highest', 'result', 'number'] + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past + + +class PLBartClassificationHead(BartClassificationHead): + pass + + +class PLBartForSequenceClassification(BigBirdPegasusForSequenceClassification): + def forward(**super_kwargs): + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (: + obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: + generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. + cross_attn_head_mask (: + obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify + selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + super().forward(**super_kwargs) + + +class PLBartForCausalLM(BartForCausalLM): + @auto_docstring + def forward(**super_kwargs): + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, PLBartForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base") + >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + super().forward(**super_kwargs) + + +__all__ = [ + "PLBartForCausalLM", + "PLBartForConditionalGeneration", + "PLBartForSequenceClassification", + "PLBartModel", + "PLBartPreTrainedModel", +] diff --git a/src/transformers/models/sew/feature_extractor_sew.py b/src/transformers/models/sew/feature_extractor_sew.py new file mode 100644 index 0000000000..c58812b58d --- /dev/null +++ b/src/transformers/models/sew/feature_extractor_sew.py @@ -0,0 +1,34 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sew/modular_sew.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sew.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings + +from .modeling_sew import SEWFeatureEncoder + + +class SEWFeatureExtractor(SEWFeatureEncoder): + def __init__(self, config): + super().__init__(config) + warnings.warn( + f"The class `{self.__class__.__name__}` has been depreciated " + "and will be removed in Transformers v5. " + f"Use `{self.__class__.__bases__[0].__name__}` instead.", + FutureWarning, + ) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index b4d8584fe0..330cd99a7b 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sew/modular_sew.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sew.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved. # @@ -12,40 +18,614 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch SEW model.""" import math import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import auto_docstring, logging -from ...utils.deprecation import deprecate_kwarg from .configuration_sew import SEWConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) -_HIDDEN_STATES_START_POSITION = 1 +class SEWNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class SEWLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class SEWGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class SEWPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + stride=config.squeeze_factor, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + return hidden_states + + +class SEWSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +class SEWUpsampling(nn.Module): + def __init__(self, config): + super().__init__() + self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor) + self.activation = ACT2FN[config.feat_extract_activation] + self.squeeze_factor = config.squeeze_factor + + def forward(self, hidden_states): + hidden_states = self.projection(hidden_states) + hidden_states = self.activation(hidden_states) + + if self.squeeze_factor > 1: + # transform embedding channels to sequence length + bsz, src_len, src_embed_dim = hidden_states.size() + tgt_len = src_len * self.squeeze_factor + tgt_embed_dim = src_embed_dim // self.squeeze_factor + hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim) + hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim) + + return hidden_states + + +class SEWFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [SEWGroupNormConvLayer(config, layer_id=0)] + [ + SEWNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [SEWLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + if self._requires_grad and self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + conv_layer.__call__, + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SEWAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[SEWConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) + + # get query proj + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights, past_key_value + + +class SEWFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class SEWEncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = SEWAttention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + config=config, + ) + + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = SEWFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + attn_residual = hidden_states + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SEWEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = SEWPositionalConvEmbedding(config) + self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.upsample = SEWUpsampling(config) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + if self.config._attn_implementation == "flash_attention_2": + # make sure padded tokens output 0 + hidden_states[~expand_attention_mask] = 0.0 + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # make sure padded tokens output 0 + hidden_states[~expand_attention_mask] = 0.0 + input_lengths = (attention_mask.long()).sum(-1) + # apply pooling formula to get real output_lengths + output_lengths = input_lengths // self.config.squeeze_factor + max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor + attention_ids = ( + torch.arange(0, max_encoder_length, device=output_lengths.device) + .view(1, -1) + .expand(output_lengths.shape[0], -1) + ) + attention_mask = (attention_ids < output_lengths.view(-1, 1)).long() + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + n_input_timesteps = hidden_states.shape[1] + + hidden_states = hidden_states.transpose(1, 2) + position_embeddings = self.pos_conv_embed(hidden_states) + pooled_hidden_states = self.pool(hidden_states) + min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1)) + hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length] + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.upsample(hidden_states) + if hidden_states.shape[1] < n_input_timesteps: + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1])) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class SEWPreTrainedModel(PreTrainedModel): + config_class = SEWConfig + base_model_prefix = "sew" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = False # needs a proper look into the mask creation + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SEWPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + if is_deepspeed_zero3_enabled(): + import deepspeed + + if hasattr(module, "weight_v") and hasattr(module, "weight_g"): + with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + nn.init.kaiming_normal_(module.weight.data) + + if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: + module.bias.data.zero_() + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, @@ -165,756 +745,6 @@ def _compute_mask_indices( return spec_aug_mask -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEW -class SEWNoLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEW -class SEWLayerNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) - self.activation = ACT2FN[config.feat_extract_activation] - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - - hidden_states = hidden_states.transpose(-2, -1) - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states.transpose(-2, -1) - - hidden_states = self.activation(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEW -class SEWGroupNormConvLayer(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 - self.out_conv_dim = config.conv_dim[layer_id] - - self.conv = nn.Conv1d( - self.in_conv_dim, - self.out_conv_dim, - kernel_size=config.conv_kernel[layer_id], - stride=config.conv_stride[layer_id], - bias=config.conv_bias, - ) - self.activation = ACT2FN[config.feat_extract_activation] - - self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.activation(hidden_states) - return hidden_states - - -class SEWPositionalConvEmbedding(nn.Module): - def __init__(self, config): - super().__init__() - self.conv = nn.Conv1d( - config.hidden_size, - config.hidden_size, - kernel_size=config.num_conv_pos_embeddings, - padding=config.num_conv_pos_embeddings // 2, - groups=config.num_conv_pos_embedding_groups, - stride=config.squeeze_factor, - ) - - weight_norm = nn.utils.weight_norm - if hasattr(nn.utils.parametrizations, "weight_norm"): - weight_norm = nn.utils.parametrizations.weight_norm - - if is_deepspeed_zero3_enabled(): - import deepspeed - - with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): - self.conv = weight_norm(self.conv, name="weight", dim=2) - if hasattr(self.conv, "parametrizations"): - weight_g = self.conv.parametrizations.weight.original0 - weight_v = self.conv.parametrizations.weight.original1 - else: - weight_g = self.conv.weight_g - weight_v = self.conv.weight_v - deepspeed.zero.register_external_parameter(self, weight_v) - deepspeed.zero.register_external_parameter(self, weight_g) - else: - self.conv = weight_norm(self.conv, name="weight", dim=2) - - self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings) - self.activation = ACT2FN[config.feat_extract_activation] - - def forward(self, hidden_states): - hidden_states = self.conv(hidden_states) - hidden_states = self.padding(hidden_states) - hidden_states = self.activation(hidden_states) - - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW -class SEWSamePadLayer(nn.Module): - def __init__(self, num_conv_pos_embeddings): - super().__init__() - self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 - - def forward(self, hidden_states): - if self.num_pad_remove > 0: - hidden_states = hidden_states[:, :, : -self.num_pad_remove] - return hidden_states - - -class SEWUpsampling(nn.Module): - def __init__(self, config): - super().__init__() - self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor) - self.activation = ACT2FN[config.feat_extract_activation] - self.squeeze_factor = config.squeeze_factor - - def forward(self, hidden_states): - hidden_states = self.projection(hidden_states) - hidden_states = self.activation(hidden_states) - - if self.squeeze_factor > 1: - # transform embedding channels to sequence length - bsz, src_len, src_embed_dim = hidden_states.size() - tgt_len = src_len * self.squeeze_factor - tgt_embed_dim = src_embed_dim // self.squeeze_factor - hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim) - hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim) - - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEW -class SEWFeatureEncoder(nn.Module): - """Construct the features from raw audio waveform""" - - def __init__(self, config): - super().__init__() - - if config.feat_extract_norm == "group": - conv_layers = [SEWGroupNormConvLayer(config, layer_id=0)] + [ - SEWNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) - ] - elif config.feat_extract_norm == "layer": - conv_layers = [SEWLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] - else: - raise ValueError( - f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" - ) - self.conv_layers = nn.ModuleList(conv_layers) - self.gradient_checkpointing = False - self._requires_grad = True - - def _freeze_parameters(self): - for param in self.parameters(): - param.requires_grad = False - self._requires_grad = False - - def forward(self, input_values): - hidden_states = input_values[:, None] - - # make sure hidden_states require grad for gradient_checkpointing - if self._requires_grad and self.training: - hidden_states.requires_grad = True - - for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) - - return hidden_states - - -class SEWFeatureExtractor(SEWFeatureEncoder): - def __init__(self, config): - super().__init__(config) - warnings.warn( - f"The class `{self.__class__.__name__}` has been depreciated " - "and will be removed in Transformers v5. " - f"Use `{self.__class__.__bases__[0].__name__}` instead.", - FutureWarning, - ) - - -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->SEW -class SEWAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - is_causal: bool = False, - config: Optional[SEWConfig] = None, - layer_idx: Optional[int] = None, - ): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - self.config = config - - if (self.head_dim * num_heads) != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." - ) - self.scaling = self.head_dim**-0.5 - self.is_decoder = is_decoder - self.is_causal = is_causal - self.layer_idx = layer_idx - if layer_idx is None and self.is_decoder: - logger.warning_once( - f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " - "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - # Ignore copy - @deprecate_kwarg("key_value_states", version="4.55") - @deprecate_kwarg("past_key_value", version="4.55") - @deprecate_kwarg("cache_position", version="4.55") - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling - - key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, None - - -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->SEW -class SEWFlashAttention2(SEWAttention): - """ - SEW flash attention module. This module inherits from `SEWAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - # Ignore copy - @deprecate_kwarg("key_value_states", version="4.55") - @deprecate_kwarg("past_key_value", version="4.55") - @deprecate_kwarg("cache_position", version="4.55") - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # SEWFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError( - "SEWSdpaAttention2 attention does not support `output_attentions`. " - "Use the argument `attn_implementation='eager'` when loading the model." - ) - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) - - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim) - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.out_proj(attn_output) - - return attn_output, None, None - - -class SEWSdpaAttention(SEWAttention): - @deprecate_kwarg("key_value_states", version="4.55") - @deprecate_kwarg("past_key_value", version="4.55") - @deprecate_kwarg("cache_position", version="4.55") - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - cache_position: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "SEWModel is using SEWSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - - causal_mask = None - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) - attn_output = self.out_proj(attn_output) - - return attn_output, None, None - - -SEW_ATTENTION_CLASSES = { - "eager": SEWAttention, - "sdpa": SEWSdpaAttention, - "flash_attention_2": SEWFlashAttention2, -} - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->SEW -class SEWFeedForward(nn.Module): - def __init__(self, config): - super().__init__() - self.intermediate_dropout = nn.Dropout(config.activation_dropout) - - self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.output_dropout = nn.Dropout(config.hidden_dropout) - - def forward(self, hidden_states): - hidden_states = self.intermediate_dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - hidden_states = self.intermediate_dropout(hidden_states) - - hidden_states = self.output_dense(hidden_states) - hidden_states = self.output_dropout(hidden_states) - return hidden_states - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->SEW, WAV2VEC2->SEW -class SEWEncoderLayer(nn.Module): - def __init__(self, config): - super().__init__() - self.attention = SEW_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=config.hidden_size, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - is_decoder=False, - ) - - self.dropout = nn.Dropout(config.hidden_dropout) - self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.feed_forward = SEWFeedForward(config) - self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - def forward(self, hidden_states, attention_mask=None, output_attentions=False): - attn_residual = hidden_states - hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) - hidden_states = self.dropout(hidden_states) - hidden_states = attn_residual + hidden_states - - hidden_states = self.layer_norm(hidden_states) - hidden_states = hidden_states + self.feed_forward(hidden_states) - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs - - -class SEWEncoder(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.pos_conv_embed = SEWPositionalConvEmbedding(config) - self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor) - self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.dropout = nn.Dropout(config.hidden_dropout) - self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.upsample = SEWUpsampling(config) - self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - - def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - if attention_mask is not None: - expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - if self._use_flash_attention_2: - # make sure padded tokens output 0 - hidden_states[~expand_attention_mask] = 0.0 - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # make sure padded tokens output 0 - hidden_states[~expand_attention_mask] = 0.0 - input_lengths = (attention_mask.long()).sum(-1) - # apply pooling formula to get real output_lengths - output_lengths = input_lengths // self.config.squeeze_factor - max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor - attention_ids = ( - torch.arange(0, max_encoder_length, device=output_lengths.device) - .view(1, -1) - .expand(output_lengths.shape[0], -1) - ) - attention_mask = (attention_ids < output_lengths.view(-1, 1)).long() - - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) - - n_input_timesteps = hidden_states.shape[1] - - hidden_states = hidden_states.transpose(1, 2) - position_embeddings = self.pos_conv_embed(hidden_states) - pooled_hidden_states = self.pool(hidden_states) - min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1)) - hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length] - hidden_states = hidden_states.transpose(1, 2) - - hidden_states = self.layer_norm(hidden_states) - hidden_states = self.dropout(hidden_states) - - synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) - - for layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - dropout_probability = torch.rand([]) - - skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or synced_gpus: - # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) - hidden_states = layer_outputs[0] - - if skip_the_layer: - layer_outputs = (None, None) - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - hidden_states = self.upsample(hidden_states) - if hidden_states.shape[1] < n_input_timesteps: - hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1])) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -@auto_docstring -class SEWPreTrainedModel(PreTrainedModel): - config_class = SEWConfig - base_model_prefix = "sew" - main_input_name = "input_values" - supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_sdpa = True - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, SEWPositionalConvEmbedding): - nn.init.normal_( - module.conv.weight, - mean=0, - std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), - ) - nn.init.constant_(module.conv.bias, 0) - elif isinstance(module, nn.Linear): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, nn.Conv1d): - if is_deepspeed_zero3_enabled(): - import deepspeed - - if hasattr(module, "weight_v") and hasattr(module, "weight_g"): - with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) - else: - with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) - else: - nn.init.kaiming_normal_(module.weight.data) - - if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: - module.bias.data.zero_() - - def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): - """ - Computes the output length of the convolutional layers - """ - - def _conv_out_length(input_length, kernel_size, stride): - # 1D convolutional layer output length formula taken - # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 - - for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): - input_lengths = _conv_out_length(input_lengths, kernel_size, stride) - - return input_lengths - - def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): - output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) - batch_size = attention_mask.shape[0] - - attention_mask = torch.zeros( - (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device - ) - # these two operations makes sure that all values before the output lengths idxs are attended to - attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 - attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() - return attention_mask - - @auto_docstring class SEWModel(SEWPreTrainedModel): def __init__(self, config: SEWConfig): @@ -1038,12 +868,14 @@ class SEWModel(SEWPreTrainedModel): ) +_HIDDEN_STATES_START_POSITION = 1 + + @auto_docstring( custom_intro=""" SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """ ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW class SEWForCTC(SEWPreTrainedModel): def __init__(self, config, target_lang: Optional[str] = None): r""" @@ -1196,11 +1028,10 @@ class SEWForCTC(SEWPreTrainedModel): @auto_docstring( custom_intro=""" - SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB - Keyword Spotting. + SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. """ ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW class SEWForSequenceClassification(SEWPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py new file mode 100644 index 0000000000..b3aa3e01b6 --- /dev/null +++ b/src/transformers/models/sew/modular_sew.py @@ -0,0 +1,469 @@ +# coding=utf-8 +# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SEW model.""" + +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Attention, + Wav2Vec2EncoderLayer, + Wav2Vec2FeatureEncoder, + Wav2Vec2FeedForward, + Wav2Vec2ForCTC, + Wav2Vec2ForSequenceClassification, + Wav2Vec2GroupNormConvLayer, + Wav2Vec2LayerNormConvLayer, + Wav2Vec2NoLayerNormConvLayer, + Wav2Vec2SamePadLayer, + _compute_mask_indices, +) +from .configuration_sew import SEWConfig + + +_HIDDEN_STATES_START_POSITION = 1 + + +class SEWNoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer): + pass + + +class SEWLayerNormConvLayer(Wav2Vec2LayerNormConvLayer): + pass + + +class SEWGroupNormConvLayer(Wav2Vec2GroupNormConvLayer): + pass + + +class SEWPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + stride=config.squeeze_factor, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + return hidden_states + + +class SEWSamePadLayer(Wav2Vec2SamePadLayer): + pass + + +class SEWUpsampling(nn.Module): + def __init__(self, config): + super().__init__() + self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor) + self.activation = ACT2FN[config.feat_extract_activation] + self.squeeze_factor = config.squeeze_factor + + def forward(self, hidden_states): + hidden_states = self.projection(hidden_states) + hidden_states = self.activation(hidden_states) + + if self.squeeze_factor > 1: + # transform embedding channels to sequence length + bsz, src_len, src_embed_dim = hidden_states.size() + tgt_len = src_len * self.squeeze_factor + tgt_embed_dim = src_embed_dim // self.squeeze_factor + hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim) + hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim) + + return hidden_states + + +class SEWFeatureEncoder(Wav2Vec2FeatureEncoder): + pass + + +class SEWFeatureExtractor(SEWFeatureEncoder): + def __init__(self, config): + super().__init__(config) + warnings.warn( + f"The class `{self.__class__.__name__}` has been depreciated " + "and will be removed in Transformers v5. " + f"Use `{self.__class__.__bases__[0].__name__}` instead.", + FutureWarning, + ) + + +class SEWAttention(Wav2Vec2Attention): + pass + + +class SEWFeedForward(Wav2Vec2FeedForward): + pass + + +class SEWEncoderLayer(Wav2Vec2EncoderLayer): + pass + + +class SEWEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = SEWPositionalConvEmbedding(config) + self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.upsample = SEWUpsampling(config) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + if self.config._attn_implementation == "flash_attention_2": + # make sure padded tokens output 0 + hidden_states[~expand_attention_mask] = 0.0 + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # make sure padded tokens output 0 + hidden_states[~expand_attention_mask] = 0.0 + input_lengths = (attention_mask.long()).sum(-1) + # apply pooling formula to get real output_lengths + output_lengths = input_lengths // self.config.squeeze_factor + max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor + attention_ids = ( + torch.arange(0, max_encoder_length, device=output_lengths.device) + .view(1, -1) + .expand(output_lengths.shape[0], -1) + ) + attention_mask = (attention_ids < output_lengths.view(-1, 1)).long() + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + n_input_timesteps = hidden_states.shape[1] + + hidden_states = hidden_states.transpose(1, 2) + position_embeddings = self.pos_conv_embed(hidden_states) + pooled_hidden_states = self.pool(hidden_states) + min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1)) + hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length] + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.upsample(hidden_states) + if hidden_states.shape[1] < n_input_timesteps: + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1])) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class SEWPreTrainedModel(PreTrainedModel): + config_class = SEWConfig + base_model_prefix = "sew" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = False # needs a proper look into the mask creation + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SEWPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + if is_deepspeed_zero3_enabled(): + import deepspeed + + if hasattr(module, "weight_v") and hasattr(module, "weight_g"): + with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + nn.init.kaiming_normal_(module.weight.data) + + if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: + module.bias.data.zero_() + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +@auto_docstring +class SEWModel(SEWPreTrainedModel): + def __init__(self, config: SEWConfig): + super().__init__(config) + self.config = config + self.feature_extractor = SEWFeatureEncoder(config) + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + + self.project_features = config.conv_dim[-1] != config.hidden_size + if self.project_features: + self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.feature_dropout = nn.Dropout(config.feat_proj_dropout) + + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.encoder = SEWEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). + """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in *config.proj_codevector_dim* space. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = self.layer_norm(extract_features) + + if self.project_features: + extract_features = self.feature_projection(extract_features) + hidden_states = self.feature_dropout(extract_features) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + + hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SEWForCTC(Wav2Vec2ForCTC): + pass + + +class SEWForSequenceClassification(Wav2Vec2ForSequenceClassification): + pass + + +__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"] diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 8db0674633..4acee66424 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -15,7 +15,7 @@ """PyTorch Speech2Text model.""" import math -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn @@ -23,21 +23,33 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, + is_torch_flex_attn_available, logging, ) from .configuration_speech_to_text import Speech2TextConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -161,7 +173,38 @@ class Speech2TextSinusoidalPositionalEmbedding(nn.Module): return incremental_indices.long() + padding_idx -# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Speech2Text +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Speech2Text class Speech2TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -196,9 +239,6 @@ class Speech2TextAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -207,6 +247,9 @@ class Speech2TextAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -214,10 +257,16 @@ class Speech2TextAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -232,18 +281,18 @@ class Speech2TextAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -255,72 +304,27 @@ class Speech2TextAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value - - -SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention} + return attn_output, attn_weights, past_key_value # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT @@ -329,7 +333,7 @@ class Speech2TextEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = Speech2TextAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -393,12 +397,13 @@ class Speech2TextEncoderLayer(nn.Module): # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT +# TODO: change copy when applying cache class class Speech2TextDecoderLayer(nn.Module): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = Speech2TextAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -411,7 +416,7 @@ class Speech2TextDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = Speech2TextAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -423,6 +428,7 @@ class Speech2TextDecoderLayer(nn.Module): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward def forward( self, hidden_states: torch.Tensor, @@ -519,6 +525,12 @@ class Speech2TextPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "input_features" supports_gradient_checkpointing = True + # TODO: tests would need a rewrite to check for correct implementation + # Current tests always assume certain inputs to be passed + _supports_flash_attn_2 = False + _supports_sdpa = False + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): std = self.config.init_std @@ -655,10 +667,10 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -713,6 +725,29 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class Speech2TextDecoder(Speech2TextPreTrainedModel): """ @@ -857,16 +892,18 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) # embed positions positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) @@ -963,6 +1000,82 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class Speech2TextModel(Speech2TextPreTrainedModel): diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 4936ae5636..3bc19a75b3 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -15,7 +15,7 @@ # limitations under the License. """PyTorch Time Series Transformer model.""" -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -23,7 +23,13 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, EncoderDecoderCache -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -31,12 +37,17 @@ from ...modeling_outputs import ( Seq2SeqTSModelOutput, Seq2SeqTSPredictionOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_time_series_transformer import TimeSeriesTransformerConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -265,6 +276,37 @@ class TimeSeriesValueEmbedding(nn.Module): return self.value_projection(x) +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->TimeSeriesTransformer class TimeSeriesTransformerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -317,17 +359,25 @@ class TimeSeriesTransformerAttention(nn.Module): layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - query_states = query_states * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) if past_key_value is not None: if isinstance(past_key_value, EncoderDecoderCache): @@ -348,8 +398,8 @@ class TimeSeriesTransformerAttention(nn.Module): else: key_states = self.k_proj(current_states) value_states = self.v_proj(current_states) - key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation @@ -361,66 +411,27 @@ class TimeSeriesTransformerAttention(nn.Module): if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = query_states.reshape(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER @@ -429,7 +440,7 @@ class TimeSeriesTransformerEncoderLayer(nn.Module): super().__init__() self.embed_dim = config.d_model - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = TimeSeriesTransformerAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -495,19 +506,13 @@ class TimeSeriesTransformerEncoderLayer(nn.Module): return outputs -# TODO: Implement attention with SDPA for TimeSeriesTransformer. -TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = { - "eager": TimeSeriesTransformerAttention, -} - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER class TimeSeriesTransformerDecoderLayer(nn.Module): def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = TimeSeriesTransformerAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -521,7 +526,7 @@ class TimeSeriesTransformerDecoderLayer(nn.Module): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = TimeSeriesTransformerAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -595,6 +600,7 @@ class TimeSeriesTransformerDecoderLayer(nn.Module): layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -626,6 +632,12 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "past_values" supports_gradient_checkpointing = True + # TODO: tests would need a rewrite to check for correct implementation + # Current tests always assume certain inputs to be passed + _supports_flash_attn_2 = False + _supports_sdpa = False + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): std = self.config.init_std @@ -640,6 +652,105 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + ): + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif self.config._attn_implementation == "flex_attention": + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): """ @@ -718,10 +829,10 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): hidden_states = self.layernorm_embedding(hidden_states + embed_pos) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -907,16 +1018,18 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device ) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) hidden_states = self.value_embedding(inputs_embeds) embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) @@ -966,6 +1079,7 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1333,7 +1447,16 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - dec_input = transformer_inputs[:, self.config.context_length :, ...] + # Avoid empty tensors and instead create a zeroes tensor which + # will be treated the same in torch, i.e. matmul with empty == all 0s + if self.config.context_length >= transformer_inputs.shape[1]: + bsz, _, dim = transformer_inputs.shape + dec_input = torch.zeros( + size=(bsz, 1, dim), device=transformer_inputs.device, dtype=transformer_inputs.dtype + ) + else: + dec_input = transformer_inputs[:, self.config.context_length :, ...] + decoder_outputs = self.decoder( inputs_embeds=dec_input, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 071db33fdc..07ee6608b7 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -4,10 +4,25 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_unispeech.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -17,7 +32,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -25,13 +41,14 @@ from ...modeling_outputs import ( SequenceClassifierOutput, Wav2Vec2BaseModelOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_unispeech import UniSpeechConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -263,6 +280,36 @@ class UniSpeechFeatureProjection(nn.Module): return hidden_states, norm_hidden_states +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class UniSpeechAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -297,9 +344,6 @@ class UniSpeechAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -308,6 +352,9 @@ class UniSpeechAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -315,10 +362,16 @@ class UniSpeechAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -333,18 +386,18 @@ class UniSpeechAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -356,298 +409,29 @@ class UniSpeechAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class UniSpeechFlashAttention2(UniSpeechAttention): - """ - UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class UniSpeechSdpaAttention(UniSpeechAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "UniSpeechModel is using UniSpeechSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class UniSpeechFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -672,21 +456,15 @@ class UniSpeechFeedForward(nn.Module): return hidden_states -UNISPEECH_ATTENTION_CLASSES = { - "eager": UniSpeechAttention, - "sdpa": UniSpeechSdpaAttention, - "flash_attention_2": UniSpeechFlashAttention2, -} - - class UniSpeechEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = UniSpeechAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -723,7 +501,6 @@ class UniSpeechEncoder(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -740,16 +517,11 @@ class UniSpeechEncoder(nn.Module): # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -798,6 +570,28 @@ class UniSpeechEncoder(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class UniSpeechAttnAdapterLayer(nn.Module): def __init__(self, config): @@ -827,11 +621,12 @@ class UniSpeechAttnAdapterLayer(nn.Module): class UniSpeechEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() - self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = UniSpeechAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -880,7 +675,6 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): [UniSpeechEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -894,19 +688,14 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): all_self_attentions = () if output_attentions else None if attention_mask is not None: - # make sure padded tokens are not attended to + # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + hidden_states[~expand_attention_mask] = 0 + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -957,6 +746,28 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class UniSpeechGumbelVectorQuantizer(nn.Module): """ @@ -1036,6 +847,8 @@ class UniSpeechPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py index 9c1eae48e8..795ab85967 100644 --- a/src/transformers/models/unispeech/modular_unispeech.py +++ b/src/transformers/models/unispeech/modular_unispeech.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UniSpeech model.""" + import math import warnings from dataclasses import dataclass @@ -135,6 +151,8 @@ class UniSpeechPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 8e67b41a77..8d9ac9c33f 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -4,10 +4,25 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_unispeech_sat.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -17,7 +32,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -27,13 +43,14 @@ from ...modeling_outputs import ( Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_peft_available, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging from .configuration_unispeech_sat import UniSpeechSatConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -266,6 +283,36 @@ class UniSpeechSatFeatureProjection(nn.Module): return hidden_states, norm_hidden_states +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class UniSpeechSatAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -300,9 +347,6 @@ class UniSpeechSatAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -311,6 +355,9 @@ class UniSpeechSatAttention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -318,10 +365,16 @@ class UniSpeechSatAttention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -336,18 +389,18 @@ class UniSpeechSatAttention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -359,298 +412,29 @@ class UniSpeechSatAttention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class UniSpeechSatFlashAttention2(UniSpeechSatAttention): - """ - UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class UniSpeechSatSdpaAttention(UniSpeechSatAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "UniSpeechSatModel is using UniSpeechSatSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class UniSpeechSatFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -675,21 +459,15 @@ class UniSpeechSatFeedForward(nn.Module): return hidden_states -UNISPEECH_SAT_ATTENTION_CLASSES = { - "eager": UniSpeechSatAttention, - "sdpa": UniSpeechSatSdpaAttention, - "flash_attention_2": UniSpeechSatFlashAttention2, -} - - class UniSpeechSatEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = UniSpeechSatAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -726,7 +504,6 @@ class UniSpeechSatEncoder(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([UniSpeechSatEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -743,16 +520,11 @@ class UniSpeechSatEncoder(nn.Module): # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -801,6 +573,28 @@ class UniSpeechSatEncoder(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class UniSpeechSatAttnAdapterLayer(nn.Module): def __init__(self, config): @@ -830,11 +624,12 @@ class UniSpeechSatAttnAdapterLayer(nn.Module): class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() - self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = UniSpeechSatAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -883,7 +678,6 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): [UniSpeechSatEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -897,19 +691,14 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): all_self_attentions = () if output_attentions else None if attention_mask is not None: - # make sure padded tokens are not attended to + # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + hidden_states[~expand_attention_mask] = 0 + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -960,6 +749,28 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class UniSpeechSatGumbelVectorQuantizer(nn.Module): """ @@ -1039,6 +850,8 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py index 22de13c3bd..9f9e7d4f3c 100644 --- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UniSpeechSat model.""" + import math import warnings from dataclasses import dataclass @@ -145,6 +161,8 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 7eb38da9ac..fb01234e3f 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -17,7 +17,7 @@ import math import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -28,7 +28,11 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -38,7 +42,8 @@ from ...modeling_outputs import ( Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( ModelOutput, auto_docstring, @@ -46,6 +51,7 @@ from ...utils import ( check_torch_load_is_safe, is_peft_available, is_safetensors_available, + is_torch_flex_attn_available, logging, ) from .configuration_wav2vec2 import Wav2Vec2Config @@ -58,8 +64,8 @@ if is_safetensors_available(): from safetensors.torch import load_file as safe_load_file -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -465,7 +471,37 @@ class Wav2Vec2FeatureProjection(nn.Module): return hidden_states, norm_hidden_states -# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Wav2Vec2 +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Wav2Vec2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -500,9 +536,6 @@ class Wav2Vec2Attention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -511,6 +544,9 @@ class Wav2Vec2Attention(nn.Module): attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -518,10 +554,16 @@ class Wav2Vec2Attention(nn.Module): # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -536,18 +578,18 @@ class Wav2Vec2Attention(nn.Module): value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -559,307 +601,29 @@ class Wav2Vec2Attention(nn.Module): # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.hubert.modeling_hubert.HubertFlashAttention2 with Hubert->Wav2Vec2 -class Wav2Vec2FlashAttention2(Wav2Vec2Attention): - """ - Wav2Vec2 flash attention module. This module inherits from `Wav2Vec2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + output_attentions=output_attentions, + head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class Wav2Vec2SdpaAttention(Wav2Vec2Attention): - # Copied from transformers.models.hubert.modeling_hubert.HubertSdpaAttention.forward with Hubert->Wav2Vec2 - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Wav2Vec2Model is using Wav2Vec2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -WAV2VEC2_ATTENTION_CLASSES = { - "eager": Wav2Vec2Attention, - "sdpa": Wav2Vec2SdpaAttention, - "flash_attention_2": Wav2Vec2FlashAttention2, -} - - class Wav2Vec2FeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -887,11 +651,12 @@ class Wav2Vec2FeedForward(nn.Module): class Wav2Vec2EncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = WAV2VEC2_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = Wav2Vec2Attention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -922,11 +687,12 @@ class Wav2Vec2EncoderLayer(nn.Module): class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() - self.attention = WAV2VEC2_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = Wav2Vec2Attention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -973,7 +739,6 @@ class Wav2Vec2Encoder(nn.Module): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -990,16 +755,11 @@ class Wav2Vec2Encoder(nn.Module): # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -1048,6 +808,29 @@ class Wav2Vec2Encoder(nn.Module): attentions=all_self_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class Wav2Vec2EncoderStableLayerNorm(nn.Module): def __init__(self, config): @@ -1060,7 +843,6 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -1074,19 +856,14 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): all_self_attentions = () if output_attentions else None if attention_mask is not None: - # make sure padded tokens are not attended to + # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + hidden_states[~expand_attention_mask] = 0 + + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -1137,6 +914,29 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): attentions=all_self_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa": + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class Wav2Vec2GumbelVectorQuantizer(nn.Module): """ @@ -1296,6 +1096,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index fc660d06ac..af02fa91ee 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -26,7 +26,11 @@ from ...modeling_outputs import ( XVectorOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import ModelOutput, auto_docstring, is_peft_available +from ...utils import ( + ModelOutput, + auto_docstring, + is_peft_available, +) from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 90330c2645..11670ea7d2 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -617,6 +617,7 @@ class WavLMPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = False _supports_sdpa = False + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py index b74d9fc703..1ff9d5052c 100644 --- a/src/transformers/models/wavlm/modular_wavlm.py +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -527,6 +527,7 @@ class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = False _supports_sdpa = False + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e42c7ce306..ef847d6059 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -27,7 +27,10 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import ( + flash_attn_supports_top_left_mask, + is_flash_attn_available, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -259,6 +262,9 @@ class WhisperAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def forward( self, hidden_states: torch.Tensor, @@ -572,7 +578,8 @@ WHISPER_ATTENTION_CLASSES = { } -# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER +# (BC Dep) Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER +# TODO(vasqu): fix copies when enabling whisper attn interface class WhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index dc1e268cbe..e9f3f63c96 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -282,7 +282,7 @@ class XGLMDecoderLayer(nn.Module): self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward def forward( self, hidden_states: torch.Tensor, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 783e7cb3a7..57037ee435 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1148,6 +1148,10 @@ class GenerationTesterMixin: # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") @@ -1228,6 +1232,10 @@ class GenerationTesterMixin: # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") @@ -1282,6 +1290,10 @@ class GenerationTesterMixin: # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm config, inputs_dict = self.prepare_config_and_inputs_for_generate() + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # Encoder-decoder models are not supported if config.is_encoder_decoder: self.skipTest("DoLa is not supported for encoder-decoder models") @@ -1346,6 +1358,10 @@ class GenerationTesterMixin: # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 34fbcd80f9..ded8d5f0a8 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1521,3 +1521,7 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py index 83a5c73ff7..bec16cf5dc 100644 --- a/tests/models/blenderbot/test_modeling_blenderbot.py +++ b/tests/models/blenderbot/test_modeling_blenderbot.py @@ -554,3 +554,7 @@ class BlenderbotStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix @unittest.skip(reason="decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py index 5e9163376a..8d75649d8c 100644 --- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py @@ -563,3 +563,7 @@ class BlenderbotSmallStandaloneDecoderModelTest(ModelTesterMixin, GenerationTest @unittest.skip(reason="decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/data2vec/test_modeling_data2vec_audio.py b/tests/models/data2vec/test_modeling_data2vec_audio.py index 96e970beb6..5a8f410a70 100644 --- a/tests/models/data2vec/test_modeling_data2vec_audio.py +++ b/tests/models/data2vec/test_modeling_data2vec_audio.py @@ -420,6 +420,9 @@ class Data2VecAudioModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index 6df9337413..9aceea0359 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -412,6 +412,10 @@ class EncoderDecoderMixin: labels, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -445,6 +449,10 @@ class EncoderDecoderMixin: # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded # from the inner models' configurations. + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) diff --git a/tests/models/hubert/test_modeling_hubert.py b/tests/models/hubert/test_modeling_hubert.py index c4e65cfa5c..de26f4c7a4 100644 --- a/tests/models/hubert/test_modeling_hubert.py +++ b/tests/models/hubert/test_modeling_hubert.py @@ -370,6 +370,9 @@ class HubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) @@ -632,6 +635,9 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index c264de826c..ba6e8f9c25 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -357,7 +357,8 @@ class M2M100ModelIntegrationTests(unittest.TestCase): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]], device=torch_device + [[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]]], + device=torch_device, ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) @@ -374,7 +375,8 @@ class M2M100ModelIntegrationTests(unittest.TestCase): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]], device=torch_device + [[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]]], + device=torch_device, ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) @@ -426,7 +428,7 @@ class M2M100ModelIntegrationTests(unittest.TestCase): Overwriting the common test as the test is flaky on tiny models """ model = M2M100ForConditionalGeneration.from_pretrained( - "facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2" + "facebook/m2m100_418M", attn_implementation="flash_attention_2", torch_dtype=torch.float16 ).to(torch_device) tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en") diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index a53627852f..ed42b1b29f 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -850,3 +850,7 @@ class MarianStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index 3b328cfe9e..4ef22c3c30 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -735,3 +735,7 @@ class MBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, u @unittest.skip(reason="Decoder cannot retain gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot retain gradients") + def test_flex_attention_with_grads(self): + return diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index b14df8de0e..1a27192506 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -728,6 +728,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # force eager attention to support output attentions + config._attn_implementation = "eager" + for model_class in self.all_model_classes: self.check_musicgen_model_output_attentions(model_class, config, **inputs_dict) self.check_musicgen_model_output_attentions_from_config(model_class, config, **inputs_dict) @@ -805,6 +808,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, config.text_encoder.output_attentions = True config.decoder.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) @@ -1036,30 +1042,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, @mark.flash_attn_test @slow def test_flash_attn_2_conversion(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - ).to(torch_device) - - for _, module in model.named_modules(): - if "FlashAttention" in module.__class__.__name__: - return - - self.assertTrue(False, "FlashAttention2 modules not found in model") + self.skipTest(reason="Musicgen doesn't use the MusicgenFlashAttention2 class method.") @require_torch_sdpa @require_torch_gpu @@ -1234,18 +1217,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, self.assertTrue(model_eager.decoder.config._attn_implementation == "eager") self.assertTrue(model_eager.config._attn_implementation == "eager") - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - def test_requires_grad_with_frozen_encoders(self): config = self.model_tester.get_config() for model_class in self.all_model_classes: @@ -1276,6 +1247,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_generation_tester_mixin_inheritance(self): pass + @unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5.")) + def test_sdpa_can_compile_dynamic(self): + pass + def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): """Produces a series of 'bip bip' sounds at a given frequency.""" diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 72cbb990c9..abf2edd1ce 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -731,6 +731,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # force eager attention to support output attentions + config._attn_implementation = "eager" + for model_class in self.all_model_classes: self.check_musicgen_melody_model_output_attentions(model_class, config, **inputs_dict) self.check_musicgen_melody_model_output_attentions_from_config(model_class, config, **inputs_dict) @@ -807,6 +810,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester config.text_encoder.output_attentions = True config.decoder.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) @@ -1036,30 +1042,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester @mark.flash_attn_test @slow def test_flash_attn_2_conversion(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - ).to(torch_device) - - for _, module in model.named_modules(): - if "FlashAttention" in module.__class__.__name__: - return - - self.assertTrue(False, "FlashAttention2 modules not found in model") + self.skipTest(reason="MusicgenMelody doesn't use the MusicgenMelodyFlashAttention2 class method.") @require_torch_sdpa @require_torch_gpu @@ -1234,18 +1217,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester self.assertTrue(model_eager.decoder.config._attn_implementation == "eager") self.assertTrue(model_eager.config._attn_implementation == "eager") - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - def test_requires_grad_with_frozen_encoders(self): config = self.model_tester.get_config() for model_class in self.all_model_classes: @@ -1276,6 +1247,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_generation_tester_mixin_inheritance(self): pass + @unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5.")) + def test_sdpa_can_compile_dynamic(self): + pass + # Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): diff --git a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py index ad00eab111..9b48a2e307 100644 --- a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py +++ b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py @@ -480,7 +480,7 @@ class PatchTSMixerModelIntegrationTests(unittest.TestCase): ) self.assertEqual(output.shape, expected_shape) - expected_slice = torch.tensor([[[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],]],device=torch_device) # fmt: skip + expected_slice = torch.tensor([[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],],device=torch_device) # fmt: skip torch.testing.assert_close(output[0, :7, :1, :1], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) def test_forecasting_head(self): diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index 2b7a127d48..af119c41d3 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -597,3 +597,7 @@ class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index 2ffdc46367..a6bf913e4c 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -590,7 +590,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]], device=torch_device + [[[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]]], device=torch_device ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) @@ -608,7 +608,8 @@ class PegasusXModelIntegrationTests(unittest.TestCase): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]], device=torch_device + [[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]]], + device=torch_device, ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) @@ -635,8 +636,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase): batch_input, max_length=512, padding="max_length", - truncation_strategy="only_first", - truncation=True, + truncation="only_first", return_tensors="pt", ) @@ -872,3 +872,7 @@ class PegasusXStandaloneDecoderModelTest(ModelTesterMixin, unittest.TestCase): @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index 569a63d435..1797505828 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -677,3 +677,7 @@ class PLBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/sew/test_modeling_sew.py b/tests/models/sew/test_modeling_sew.py index 9a7430c641..2cab21cf5c 100644 --- a/tests/models/sew/test_modeling_sew.py +++ b/tests/models/sew/test_modeling_sew.py @@ -342,6 +342,9 @@ class SEWModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py index 7ac8d5631d..28cdaf3447 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py @@ -300,6 +300,10 @@ class EncoderDecoderMixin: input_features=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] diff --git a/tests/models/unispeech/test_modeling_unispeech.py b/tests/models/unispeech/test_modeling_unispeech.py index 0f8e244832..ebc537a478 100644 --- a/tests/models/unispeech/test_modeling_unispeech.py +++ b/tests/models/unispeech/test_modeling_unispeech.py @@ -383,6 +383,9 @@ class UniSpeechRobustModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/unispeech_sat/test_modeling_unispeech_sat.py b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py index dd8deafc22..ec438dea96 100644 --- a/tests/models/unispeech_sat/test_modeling_unispeech_sat.py +++ b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py @@ -423,6 +423,9 @@ class UniSpeechSatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) @@ -632,6 +635,9 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index bfd9aad233..ffd08297f1 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -246,6 +246,10 @@ class EncoderDecoderMixin: pixel_values=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -480,6 +484,10 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase): pixel_values=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -670,6 +678,10 @@ class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase): pixel_values=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -807,6 +819,10 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase): labels=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -929,6 +945,10 @@ class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase): labels=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -1047,6 +1067,10 @@ class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase): labels=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index c47d4855cc..9597d2e6ef 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -570,6 +570,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) @@ -917,6 +920,9 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2f16a86d80..a85c9e7e62 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -958,6 +958,8 @@ class ModelTesterMixin: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True + # force eager attention to support output attentions + config._attn_implementation = "eager" seq_len = getattr(self.model_tester, "seq_length", None) decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) @@ -1106,7 +1108,11 @@ class ModelTesterMixin: configs_no_init.torchscript = True for model_class in self.all_model_classes: for attn_implementation in ["eager", "sdpa"]: - if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()): + if ( + attn_implementation == "sdpa" + and (not model_class._supports_sdpa or not is_torch_sdpa_available()) + or config.output_attentions + ): continue configs_no_init._attn_implementation = attn_implementation @@ -1708,6 +1714,10 @@ class ModelTesterMixin: config.output_hidden_states = True config.output_attentions = self.has_attentions + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) @@ -4555,13 +4565,26 @@ class ModelTesterMixin: # TODO: raushan, fix for composite models after making VLMs support new attn API if not model_class._supports_flex_attn or self._is_composite: self.skipTest(reason="This model does not support flex attention") + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config._attn_implementation = "flex_attention" - model = model_class(config).to(device=torch_device, dtype=torch.float16) + # Flex Attention can not use dropout + if hasattr(config, "attention_dropout"): + config.attention_dropout = 0 + if hasattr(config, "attention_probs_dropout_prob"): + config.attention_probs_dropout_prob = 0 + + model = model_class(config).to(device=torch_device) self.assertTrue(model.config._attn_implementation == "flex_attention") + # Elaborate workaround for encoder-decoder models as some do not specify their main input + dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)} + if config.is_encoder_decoder: + dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"] + dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"] + # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) - _ = model(inputs_dict["input_ids"].to(torch_device)) + _ = model(**dummy_inputs) def test_generation_tester_mixin_inheritance(self): """