From ac5893756bafcd745d93a442cf36f984545dbad8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 27 Oct 2023 16:42:01 +0200 Subject: [PATCH] [Attention Mask] Refactor all encoder-decoder attention mask (#27086) * [FA2 Bart] Add FA2 to all Bart-like * better * Refactor attention mask * remove all customized atteniton logic * format * mass rename * replace _expand_mask * replace _expand_mask * mass rename * add pt files * mass replace & rename * mass replace & rename * mass replace & rename * mass replace & rename * Update src/transformers/models/idefics/modeling_idefics.py * fix more * clean more * fix more * make style * fix again * finish * finish * finish * finish * finish * finish * finish * finish * finish * finish * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * small fix mistral * finish * finish * finish * finish --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/modeling_attn_mask_utils.py | 247 ++++++++++++++++++ .../models/autoformer/modeling_autoformer.py | 63 +---- src/transformers/models/bart/modeling_bart.py | 63 +---- .../modeling_bigbird_pegasus.py | 65 +---- .../models/biogpt/modeling_biogpt.py | 60 +---- .../models/blenderbot/modeling_blenderbot.py | 66 +---- .../modeling_blenderbot_small.py | 66 +---- .../models/bloom/modeling_bloom.py | 57 +--- src/transformers/models/clip/modeling_clip.py | 40 +-- .../models/clipseg/modeling_clipseg.py | 40 +-- .../modeling_conditional_detr.py | 35 +-- .../modeling_deformable_detr.py | 18 +- .../models/deprecated/mctct/modeling_mctct.py | 18 +- .../open_llama/modeling_open_llama.py | 64 +---- src/transformers/models/deta/modeling_deta.py | 18 +- src/transformers/models/detr/modeling_detr.py | 25 +- .../models/falcon/modeling_falcon.py | 181 +------------ src/transformers/models/git/modeling_git.py | 22 +- .../models/groupvit/modeling_groupvit.py | 40 +-- .../models/idefics/modeling_idefics.py | 59 +---- .../models/informer/modeling_informer.py | 65 +---- src/transformers/models/led/modeling_led.py | 32 +-- .../models/llama/modeling_llama.py | 164 +----------- .../models/m2m_100/modeling_m2m_100.py | 57 +--- .../models/marian/modeling_marian.py | 66 +---- .../mask2former/modeling_mask2former.py | 15 -- .../models/maskformer/modeling_maskformer.py | 32 +-- .../models/mbart/modeling_mbart.py | 66 +---- .../models/mistral/modeling_mistral.py | 177 +------------ src/transformers/models/mpt/modeling_mpt.py | 70 +---- .../models/musicgen/modeling_musicgen.py | 64 +---- src/transformers/models/mvp/modeling_mvp.py | 63 +---- .../models/nllb_moe/modeling_nllb_moe.py | 57 +--- src/transformers/models/opt/modeling_opt.py | 59 +---- .../models/owlv2/modeling_owlv2.py | 41 +-- .../models/owlvit/modeling_owlvit.py | 41 +-- .../models/pegasus/modeling_pegasus.py | 66 +---- .../models/pegasus_x/modeling_pegasus_x.py | 89 +------ .../models/persimmon/modeling_persimmon.py | 60 +---- .../models/plbart/modeling_plbart.py | 65 +---- .../seamless_m4t/modeling_seamless_m4t.py | 68 +---- .../speech_to_text/modeling_speech_to_text.py | 64 +---- .../modeling_speech_to_text_2.py | 61 +---- .../models/speecht5/modeling_speecht5.py | 66 +---- .../modeling_table_transformer.py | 25 +- .../modeling_time_series_transformer.py | 65 +---- .../models/trocr/modeling_trocr.py | 61 +---- src/transformers/models/vits/modeling_vits.py | 18 +- .../models/whisper/modeling_whisper.py | 58 +--- .../models/x_clip/modeling_x_clip.py | 40 +-- src/transformers/models/xglm/modeling_xglm.py | 61 +---- ...ng_{{cookiecutter.lowercase_modelname}}.py | 57 +--- tests/models/llama/test_modeling_llama.py | 143 ---------- tests/test_modeling_utils.py | 142 ++++++++++ utils/check_inits.py | 1 + 55 files changed, 647 insertions(+), 2879 deletions(-) create mode 100755 src/transformers/modeling_attn_mask_utils.py diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py new file mode 100755 index 0000000000..434b32ce7f --- /dev/null +++ b/src/transformers/modeling_attn_mask_utils.py @@ -0,0 +1,247 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple, Union + +import torch + + +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: Optional[int] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." + ) + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask + + return expanded_4d_mask + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + return attention_mask + + +def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _create_4d_causal_attention_mask( + input_shape: Union[torch.Size, Tuple, List], + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` + + Args: + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + device (`int`): + The torch device the created mask shall have. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. + """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = past_key_values_length + input_shape[-1] + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device + ) + + return attention_mask diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 12a0951c88..92e9df2c7e 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -26,6 +26,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, ModelOutput, @@ -357,39 +358,6 @@ def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch. return -input.log_prob(target) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Autoformer class AutoformerSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" @@ -1176,7 +1144,7 @@ class AutoformerEncoder(AutoformerPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1262,29 +1230,6 @@ class AutoformerDecoder(AutoformerPreTrainedModel): # Initialize weights and apply final processing self.post_init() - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ).to(inputs_embeds.device) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, trend: Optional[torch.Tensor] = None, @@ -1374,7 +1319,9 @@ class AutoformerDecoder(AutoformerPreTrainedModel): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) hidden_states = self.value_embedding(inputs_embeds) embed_pos = self.embed_positions( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 20ada97627..c271aabcb4 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -86,37 +87,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class BartLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -823,7 +793,7 @@ class BartEncoder(BartPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -919,29 +889,6 @@ class BartDecoder(BartPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -1048,14 +995,16 @@ class BartDecoder(BartPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input, past_key_values_length) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index e7841d8f59..cc69e31e5a 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -77,38 +78,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) - - class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -1902,7 +1871,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): # expand attention_mask if self.attention_type == "original_full": # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) blocked_encoder_mask = band_mask = from_mask = to_mask = None elif self.attention_type == "block_sparse": blocked_encoder_mask, band_mask, from_mask, to_mask = self.create_masks_for_block_sparse_attn( @@ -2101,30 +2070,6 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -2229,14 +2174,16 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 157607b73c..2cbfa4ef0f 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -53,39 +54,6 @@ BIOGPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding with OPT->BioGpt class BioGptLearnedPositionalEmbedding(nn.Embedding): """ @@ -472,30 +440,6 @@ class BioGptModel(BioGptPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - @add_start_docstrings_to_model_forward(BIOGPT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -554,7 +498,7 @@ class BioGptModel(BioGptPreTrainedModel): # embed positions positions = self.embed_positions(attention_mask, past_key_values_length) - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 4a1a7d0732..985c660fa0 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -27,6 +27,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -75,39 +76,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class BlenderbotLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -747,7 +715,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -845,30 +813,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids=None, @@ -974,14 +918,16 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 5755576b4d..3d51ee9128 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -71,39 +72,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.blenderbot.modeling_blenderbot.BlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall class BlenderbotSmallLearnedPositionalEmbedding(nn.Embedding): """ @@ -745,7 +713,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -840,30 +808,6 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids=None, @@ -968,14 +912,16 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index f94e371256..3064d5bae6 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn import functional as F from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -53,36 +54,6 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -def _make_causal_mask( - input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int -) -> torch.BoolTensor: - """ - Make causal mask used for self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device) - # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround - seq_ids = torch.arange(target_length, device=device) - mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :] - - if past_key_values_length > 0: - mask[:, :past_key_values_length] = False - - expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) - return expanded_mask - - -def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: - """ - Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. - """ - batch_size, src_length = mask.shape - tgt_length = tgt_length if tgt_length is not None else src_length - - expanded_mask = ~(mask[:, None, None, :].to(torch.bool)) - return expanded_mask.expand(batch_size, 1, tgt_length, src_length) - - def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it @@ -638,28 +609,6 @@ class BloomModel(BloomPreTrainedModel): def get_input_embeddings(self): return self.word_embeddings - def _prepare_attn_mask( - self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int - ) -> torch.BoolTensor: - # create causal mask - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - if src_length > 1: - combined_attention_mask = _make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - def set_input_embeddings(self, new_embeddings: torch.Tensor): self.word_embeddings = new_embeddings @@ -746,11 +695,13 @@ class BloomModel(BloomPreTrainedModel): alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - causal_mask = self._prepare_attn_mask( + causal_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape=(batch_size, seq_length), + inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) + causal_mask = causal_mask.bool() for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 4d2c96ecec..ccc253cc98 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -45,21 +46,6 @@ CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # contrastive loss function, adapted from # https://sachinruk.github.io/blog/2021-03-07-clip.html def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: @@ -665,24 +651,6 @@ class CLIPEncoder(nn.Module): ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - class CLIPTextTransformer(nn.Module): def __init__(self, config: CLIPTextConfig): super().__init__() @@ -726,11 +694,13 @@ class CLIPTextTransformer(nn.Module): # CLIP's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index a07ceedd72..4ef9395eb2 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -24,6 +24,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -47,21 +48,6 @@ CLIPSEG_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # contrastive loss function, adapted from # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: @@ -674,24 +660,6 @@ class CLIPSegEncoder(nn.Module): ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - class CLIPSegTextTransformer(nn.Module): # Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer.__init__ with CLIP->CLIPSeg def __init__(self, config: CLIPSegTextConfig): @@ -737,11 +705,13 @@ class CLIPSegTextTransformer(nn.Module): # CLIPSeg's text model uses causal mask, prepare it here. # https://github.com/openai/CLIPSeg/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clipseg/model.py#L324 - causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index b964d72704..961febc480 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -23,6 +23,7 @@ import torch from torch import Tensor, nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -416,22 +417,6 @@ class ConditionalDetrConvModel(nn.Module): return out, pos -# Copied from transformers.models.detr.modeling_detr._expand_mask with Detr->ConditionalDetr -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): - """ - Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. - """ - batch_size, source_len = mask.size() - target_len = target_len if target_len is not None else source_len - - expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) - - -# Copied from transformers.models.detr.modeling_detr.DetrSinePositionEmbedding with Detr->ConditionalDetr class ConditionalDetrSinePositionEmbedding(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you @@ -1319,7 +1304,7 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel): # expand attention_mask if attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1468,19 +1453,11 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel): hidden_states = inputs_embeds input_shape = inputs_embeds.size()[:-1] - combined_attention_mask = None - - if attention_mask is not None and combined_attention_mask is not None: - # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, target_len=input_shape[-1] - ) - # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - encoder_attention_mask = _expand_mask( - encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) # optional intermediate hidden states @@ -1517,7 +1494,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - combined_attention_mask, + None, object_queries, query_position_embeddings, query_sine_embed, @@ -1529,7 +1506,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=None, object_queries=object_queries, query_position_embeddings=query_position_embeddings, query_sine_embed=query_sine_embed, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 03d1a3950f..2a93fc2aa0 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -39,6 +39,7 @@ from ...file_utils import ( replace_return_docstrings, requires_backends, ) +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid @@ -463,21 +464,6 @@ class DeformableDetrConvModel(nn.Module): return out, pos -# Copied from transformers.models.detr.modeling_detr._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): - """ - Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. - """ - batch_size, source_len = mask.size() - target_len = target_len if target_len is not None else source_len - - expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) - - class DeformableDetrSinePositionEmbedding(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you @@ -801,7 +787,7 @@ class DeformableDetrMultiheadAttention(nn.Module): # expand attention_mask if attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) if attention_mask is not None: if attention_mask.size() != (batch_size, 1, target_len, source_len): diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index 525dfec2ab..cb3186c9dd 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -25,6 +25,7 @@ from torch import nn from ....activations import ACT2FN from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ....integrations.deepspeed import is_deepspeed_zero3_enabled +from ....modeling_attn_mask_utils import _prepare_4d_attention_mask from ....modeling_outputs import BaseModelOutput, CausalLMOutput from ....modeling_utils import ( PreTrainedModel, @@ -57,21 +58,6 @@ MCTCT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class MCTCTConv1dSubsampler(nn.Module): """ Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation @@ -587,7 +573,7 @@ class MCTCTEncoder(MCTCTPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 8d99c79ef8..609bb39df1 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -27,6 +27,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import ACT2FN +from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings @@ -44,39 +45,6 @@ except ImportError: _CONFIG_FOR_DOC = "OpenLlamaConfig" -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->OpenLlama class OpenLlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -556,30 +524,6 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING) def forward( self, @@ -636,8 +580,10 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel): attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + + input_shape = (batch_size, seq_length) + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length ) hidden_states = inputs_embeds diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index ab397ebb99..db607b183b 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -34,6 +34,7 @@ from ...file_utils import ( is_vision_available, replace_return_docstrings, ) +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid @@ -364,21 +365,6 @@ class DetaBackboneWithPositionalEncodings(nn.Module): return out, pos -# Copied from transformers.models.detr.modeling_detr._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): - """ - Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. - """ - batch_size, source_len = mask.size() - target_len = target_len if target_len is not None else source_len - - expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) - - # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrSinePositionEmbedding with DeformableDetr->Deta class DetaSinePositionEmbedding(nn.Module): """ @@ -687,7 +673,7 @@ class DetaMultiheadAttention(nn.Module): # expand attention_mask if attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) if attention_mask is not None: if attention_mask.size() != (batch_size, 1, target_len, source_len): diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 08e11e06e6..9b73d0e651 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -23,6 +23,7 @@ import torch from torch import Tensor, nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -408,20 +409,6 @@ class DetrConvModel(nn.Module): return out, pos -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): - """ - Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. - """ - batch_size, source_len = mask.size() - target_len = target_len if target_len is not None else source_len - - expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) - - class DetrSinePositionEmbedding(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you @@ -1073,7 +1060,7 @@ class DetrEncoder(DetrPreTrainedModel): # expand attention_mask if attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1220,15 +1207,15 @@ class DetrDecoder(DetrPreTrainedModel): if attention_mask is not None and combined_attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, target_len=input_shape[-1] + combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - encoder_attention_mask = _expand_mask( - encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) # optional intermediate hidden states diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 6f4f0838e6..511c55a848 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn import functional as F +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -89,148 +90,6 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter -class AttnMaskConverter: - """ - A utility attention mask class that allows: - - Create a causal 4d mask - - Create a causal 4d mask with slided window - - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, - key_value_length) that can be multiplied with attention scores - - Parameters: - is_causal (`bool`): - Whether the attention mask should be a uni-directional (causal) or bi-directional mask. - - sliding_window (`int`, *optional*): - Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. - """ - - def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): - self.is_causal = is_causal - self.sliding_window = sliding_window - - def to_causal_4d( - self, - batch_size: int, - query_length: int, - key_value_length: int, - dtype: torch.dtype = torch.float32, - device: Union[torch.device, "str"] = "cpu", - ) -> torch.Tensor: - """ - Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative - bias to upper right hand triangular matrix (causal mask). - """ - if not self.is_causal: - raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") - - # If shape is not cached, create a new causal mask and cache it - input_shape = (batch_size, query_length) - past_key_values_length = key_value_length - query_length - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if input_shape[-1] > 1 or self.sliding_window is not None: - past_key_values_length = key_value_length - query_length - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - - return causal_4d_mask - - def to_4d( - self, - attention_mask_2d: torch.Tensor, - query_length: int, - key_value_length: Optional[int] = None, - dtype: torch.dtype = torch.float32, - ) -> torch.Tensor: - """ - Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, - key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is - causal, a causal mask will be added. - """ - input_shape = (attention_mask_2d.shape[0], query_length) - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: - if key_value_length is None: - raise ValueError( - "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." - ) - - past_key_values_length = key_value_length - query_length - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=attention_mask_2d.device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - elif self.sliding_window is not None: - raise NotImplementedError("Sliding window is currently only implemented for causal masking") - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( - attention_mask_2d.device - ) - expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask - - return expanded_4d_mask - - @staticmethod - def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: Optional[int] = None, - ): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - - # add lower triangular sliding window mask if necessary - if sliding_window is not None: - diagonal = past_key_values_length - sliding_window + 1 - - context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) - mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) - - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities class FalconRotaryEmbedding(nn.Module): """Implementation of RotaryEmbedding from GPT-NeoX. @@ -365,27 +224,7 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): self.sin_cached = self.sin_cached.type(dtype) -def _make_causal_mask( - input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int -) -> torch.BoolTensor: - """ - Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it - just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1, - target_length, target_length+past_key_values_length]`. - """ - batch_size, target_length = input_ids_shape - - mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1) - # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op. - # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this - # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later. - past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device) - mask = torch.cat([past_mask, mask], dim=-1) - expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) - return expanded_mask - - -def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor: +def _prepare_4d_attention_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor: """ Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`. """ @@ -904,7 +743,7 @@ class FalconMLP(nn.Module): class FalconDecoderLayer(nn.Module): - def __init__(self, config: FalconConfig, attn_mask_converter=None): + def __init__(self, config: FalconConfig): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -1153,8 +992,6 @@ class FalconModel(FalconPreTrainedModel): # Embedding + LN Embedding self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) - self.attn_mask_converter = AttnMaskConverter(is_causal=True) - # Transformer blocks self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)]) @@ -1260,16 +1097,10 @@ class FalconModel(FalconPreTrainedModel): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: - key_value_length = seq_length + past_key_values_length # 4d mask is passed through the layers - if attention_mask is not None: - attention_mask = self.attn_mask_converter.to_4d( - attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype - ) - else: - attention_mask = self.attn_mask_converter.to_causal_4d( - batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 120576bfab..8909efa386 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...file_utils import ModelOutput +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, @@ -80,21 +81,6 @@ class GitVisionModelOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class GitEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" @@ -1276,9 +1262,9 @@ class GitModel(GitPreTrainedModel): if attention_mask is not None: # if the user provides an attention mask, we add it to the default one # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]).to( - embedding_output.device - ) + expanded_attn_mask = _prepare_4d_attention_mask( + attention_mask, embedding_output.dtype, tgt_len=input_shape[-1] + ).to(embedding_output.device) if past_key_values_length > 0: expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :] else: diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 680fe78f5c..3ed4157419 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -26,6 +26,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -48,21 +49,6 @@ GROUPVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # contrastive loss function, adapted from # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: @@ -1056,24 +1042,6 @@ class GroupViTTextEncoder(nn.Module): ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - # Copied from transformers.models.clip.modeling_clip.CLIPTextTransformer with CLIPText->GroupViTText, CLIPEncoder->GroupViTTextEncoder, CLIP_TEXT->GROUPVIT_TEXT class GroupViTTextTransformer(nn.Module): def __init__(self, config: GroupViTTextConfig): @@ -1118,11 +1086,13 @@ class GroupViTTextTransformer(nn.Module): # CLIP's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index cb8945ea54..29deae594b 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -29,6 +29,7 @@ from torch.nn import CrossEntropyLoss from ... import PreTrainedModel from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ModelOutput from ...modeling_utils import PretrainedConfig from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -444,38 +445,6 @@ class IdeficsDecoupledLinear(nn.Linear): ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # this was adapted from LlamaRMSNorm class IdeficsRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -1122,30 +1091,6 @@ class IdeficsModel(IdeficsPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, @@ -1261,7 +1206,7 @@ class IdeficsModel(IdeficsPreTrainedModel): attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 5959c8538d..ebb4b2821c 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -21,6 +21,7 @@ import torch from torch import nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -230,39 +231,6 @@ def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch. return -input.log_prob(target) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Informer class InformerSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" @@ -1184,7 +1152,7 @@ class InformerEncoder(InformerPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1274,29 +1242,6 @@ class InformerDecoder(InformerPreTrainedModel): # Initialize weights and apply final processing self.post_init() - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, attention_mask: Optional[torch.Tensor] = None, @@ -1380,14 +1325,16 @@ class InformerDecoder(InformerPreTrainedModel): # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) hidden_states = self.value_embedding(inputs_embeds) embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index e757ef6a7b..12a0c2a991 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, @@ -74,22 +75,7 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min) - mask_cond = torch.arange(mask.size(-1)) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): +def _prepare_4d_attention_mask_inverted(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ @@ -1838,7 +1824,7 @@ class LEDEncoder(LEDPreTrainedModel): # convert attention_mask to float if attention_mask is not None: # [bsz, seq_len] -> [bsz, seq_len]; 1 -> 0.0; 0 -> "-inf" - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)[:, 0, 0, :] + attention_mask = _prepare_4d_attention_mask_inverted(attention_mask, inputs_embeds.dtype)[:, 0, 0, :] # get masking tensors is_index_masked = attention_mask < 0 @@ -2077,20 +2063,22 @@ class LEDDecoder(LEDPreTrainedModel): # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) + combined_attention_mask = _create_4d_causal_attention_mask( + input_shape, inputs_embeds.dtype, inputs_embeds.device, past_key_values_length=past_key_values_length + ) if attention_mask is not None and combined_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( + combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask_inverted( attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask_inverted( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a330ff62e5..4bf9da2639 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -66,163 +67,22 @@ def _get_unpad_data(attention_mask): def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): warnings.warn( - "Calling `transformers.models.llama.modeling_llama._expand_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._expand_mask" + "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask" ) - return AttnMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) def _make_causal_mask( input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 ): warnings.warn( - "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._make_causal_mask" + "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask" ) - return AttnMaskConverter._make_causal_mask( + return AttentionMaskConverter._make_causal_mask( input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length ) -class AttnMaskConverter: - """ - A utility attention mask class that allows: - - Create a causal 4d mask - - Create a causal 4d mask with slided window - - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, - key_value_length) that can be multiplied with attention scores - - Parameters: - is_causal (`bool`): - Whether the attention mask should be a uni-directional (causal) or bi-directional mask. - - sliding_window (`int`, *optional*): - Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. - """ - - def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): - self.is_causal = is_causal - self.sliding_window = sliding_window - - def to_causal_4d( - self, - batch_size: int, - query_length: int, - key_value_length: int, - dtype: torch.dtype = torch.float32, - device: Union[torch.device, "str"] = "cpu", - ) -> torch.Tensor: - """ - Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative - bias to upper right hand triangular matrix (causal mask). - """ - if not self.is_causal: - raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") - - # If shape is not cached, create a new causal mask and cache it - input_shape = (batch_size, query_length) - past_key_values_length = key_value_length - query_length - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if input_shape[-1] > 1 or self.sliding_window is not None: - past_key_values_length = key_value_length - query_length - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - - return causal_4d_mask - - def to_4d( - self, - attention_mask_2d: torch.Tensor, - query_length: int, - key_value_length: Optional[int] = None, - dtype: torch.dtype = torch.float32, - ) -> torch.Tensor: - """ - Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, - key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is - causal, a causal mask will be added. - """ - input_shape = (attention_mask_2d.shape[0], query_length) - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: - if key_value_length is None: - raise ValueError( - "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." - ) - - past_key_values_length = key_value_length - query_length - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=attention_mask_2d.device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - elif self.sliding_window is not None: - raise NotImplementedError("Sliding window is currently only implemented for causal masking") - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( - attention_mask_2d.device - ) - expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask - - return expanded_4d_mask - - @staticmethod - def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: Optional[int] = None, - ): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - - # add lower triangular sliding window mask if necessary - if sliding_window is not None: - diagonal = past_key_values_length - sliding_window + 1 - - context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) - mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) - - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -933,8 +793,6 @@ class LlamaModel(LlamaPreTrainedModel): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.attn_mask_converter = AttnMaskConverter(is_causal=True) - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -998,16 +856,10 @@ class LlamaModel(LlamaPreTrainedModel): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: - key_value_length = seq_length + past_key_values_length # 4d mask is passed through the layers - if attention_mask is not None: - attention_mask = self.attn_mask_converter.to_4d( - attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype - ) - else: - attention_mask = self.attn_mask_converter.to_causal_4d( - batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) # embed positions hidden_states = inputs_embeds diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 07338731cd..ed17796d27 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -71,39 +72,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols @@ -790,7 +758,7 @@ class M2M100Encoder(M2M100PreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -991,25 +959,16 @@ class M2M100Decoder(M2M100PreTrainedModel): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + combined_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 68d7fe53bd..10a6d5f342 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -73,39 +74,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class MarianSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" @@ -760,7 +728,7 @@ class MarianEncoder(MarianPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -851,30 +819,6 @@ class MarianDecoder(MarianPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -979,14 +923,16 @@ class MarianDecoder(MarianPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 6b3e901d71..c62e3a04c1 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -245,21 +245,6 @@ class Mask2FormerForUniversalSegmentationOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformers.models.detr.modeling_detr._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): - """ - Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. - """ - batch_size, source_len = mask.size() - target_len = target_len if target_len is not None else source_len - - expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) - - # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py def sample_point( input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 0fda64fa49..6572f5cdde 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -25,6 +25,7 @@ from torch import Tensor, nn from ... import AutoBackbone from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -702,21 +703,6 @@ class DetrDecoderLayer(nn.Module): return outputs -# Copied from transformers.models.detr.modeling_detr._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): - """ - Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. - """ - batch_size, source_len = mask.size() - target_len = target_len if target_len is not None else source_len - - expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) - - class DetrDecoder(nn.Module): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`]. @@ -817,18 +803,12 @@ class DetrDecoder(nn.Module): hidden_states = inputs_embeds input_shape = inputs_embeds.size()[:-1] - combined_attention_mask = None - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # optional intermediate hidden states intermediate = () if self.config.auxiliary_loss else None @@ -851,7 +831,7 @@ class DetrDecoder(nn.Module): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - combined_attention_mask, + None, encoder_hidden_states, encoder_attention_mask, None, @@ -860,7 +840,7 @@ class DetrDecoder(nn.Module): else: layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=None, object_queries=object_queries, query_position_embeddings=query_position_embeddings, encoder_hidden_states=encoder_hidden_states, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index e3ec189012..96044ac1c2 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -78,39 +79,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): return prev_output_tokens -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart class MBartLearnedPositionalEmbedding(nn.Embedding): """ @@ -798,7 +766,7 @@ class MBartEncoder(MBartPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -896,30 +864,6 @@ class MBartDecoder(MBartPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -1026,14 +970,16 @@ class MBartDecoder(MBartPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input, past_key_values_length) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 94ebc690aa..0eeff04185 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -30,6 +30,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -54,148 +55,6 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MistralConfig" -# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter -class AttnMaskConverter: - """ - A utility attention mask class that allows: - - Create a causal 4d mask - - Create a causal 4d mask with slided window - - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, - key_value_length) that can be multiplied with attention scores - - Parameters: - is_causal (`bool`): - Whether the attention mask should be a uni-directional (causal) or bi-directional mask. - - sliding_window (`int`, *optional*): - Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. - """ - - def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): - self.is_causal = is_causal - self.sliding_window = sliding_window - - def to_causal_4d( - self, - batch_size: int, - query_length: int, - key_value_length: int, - dtype: torch.dtype = torch.float32, - device: Union[torch.device, "str"] = "cpu", - ) -> torch.Tensor: - """ - Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative - bias to upper right hand triangular matrix (causal mask). - """ - if not self.is_causal: - raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") - - # If shape is not cached, create a new causal mask and cache it - input_shape = (batch_size, query_length) - past_key_values_length = key_value_length - query_length - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if input_shape[-1] > 1 or self.sliding_window is not None: - past_key_values_length = key_value_length - query_length - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - - return causal_4d_mask - - def to_4d( - self, - attention_mask_2d: torch.Tensor, - query_length: int, - key_value_length: Optional[int] = None, - dtype: torch.dtype = torch.float32, - ) -> torch.Tensor: - """ - Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, - key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is - causal, a causal mask will be added. - """ - input_shape = (attention_mask_2d.shape[0], query_length) - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: - if key_value_length is None: - raise ValueError( - "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." - ) - - past_key_values_length = key_value_length - query_length - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=attention_mask_2d.device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - elif self.sliding_window is not None: - raise NotImplementedError("Sliding window is currently only implemented for causal masking") - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( - attention_mask_2d.device - ) - expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask - - return expanded_4d_mask - - @staticmethod - def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: Optional[int] = None, - ): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - - # add lower triangular sliding window mask if necessary - if sliding_window is not None: - diagonal = past_key_values_length - sliding_window + 1 - - context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) - mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) - - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) @@ -209,21 +68,6 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral class MistralRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -908,8 +752,6 @@ class MistralModel(MistralPreTrainedModel): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.attn_mask_converter = AttnMaskConverter(is_causal=True, sliding_window=config.sliding_window) - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -978,7 +820,6 @@ class MistralModel(MistralPreTrainedModel): attention_mask is not None and hasattr(self.config, "_flash_attn_2_enabled") and self.config._flash_attn_2_enabled - and past_key_values is not None ): is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -992,16 +833,14 @@ class MistralModel(MistralPreTrainedModel): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: - key_value_length = seq_length + past_key_values_length # 4d mask is passed through the layers - if attention_mask is not None: - attention_mask = self.attn_mask_converter.to_4d( - attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype - ) - else: - attention_mask = self.attn_mask_converter.to_causal_4d( - batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 5fa6698d34..2477b90659 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from torch.nn import functional as F from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -55,38 +56,6 @@ MPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bloom.modeling_bloom._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int -) -> torch.BoolTensor: - """ - Make causal mask used for self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device) - # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround - seq_ids = torch.arange(target_length, device=device) - mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :] - - if past_key_values_length > 0: - mask[:, :past_key_values_length] = False - - expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) - return expanded_mask - - -# Copied from transformers.models.bloom.modeling_bloom._expand_mask -def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: - """ - Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. - """ - batch_size, src_length = mask.shape - tgt_length = tgt_length if tgt_length is not None else src_length - - expanded_mask = ~(mask[:, None, None, :].to(torch.bool)) - return expanded_mask.expand(batch_size, 1, tgt_length, src_length) - - def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None): r""" Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it @@ -412,34 +381,6 @@ class MptModel(MptPreTrainedModel): def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None): return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device) - def _prepare_attn_mask( - self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int - ) -> torch.BoolTensor: - # create causal mask - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - if input_shape[1] + past_key_values_length != attention_mask.shape[1]: - raise ValueError( - "Attention mask shape should be (batch_size, seq_length + past_key_values_length)" - f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length" - f" {past_key_values_length}." - ) - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - if src_length > 1: - combined_attention_mask = _make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - def set_input_embeddings(self, new_embeddings: torch.Tensor): self.wte = new_embeddings @@ -508,13 +449,12 @@ class MptModel(MptPreTrainedModel): alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device) - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, + causal_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) + causal_mask = causal_mask.bool() - for i, (block, layer_past) in enumerate(zip(self.blocks, past_key_values)): + for block, layer_past in zip(self.blocks, past_key_values): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index a1f2c58930..5d96359999 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...generation.configuration_utils import GenerationConfig from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList from ...generation.stopping_criteria import StoppingCriteriaList +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -80,39 +81,6 @@ class MusicgenUnconditionalInput(ModelOutput): guidance_scale: float = None -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ @@ -706,30 +674,6 @@ class MusicgenDecoder(MusicgenPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING) def forward( self, @@ -773,14 +717,16 @@ class MusicgenDecoder(MusicgenPreTrainedModel): if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input, past_key_values_length) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 62e61f84ea..2539af3618 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -89,39 +90,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MVP class MvpLearnedPositionalEmbedding(nn.Embedding): """ @@ -918,7 +886,7 @@ class MvpEncoder(MvpPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1033,27 +1001,6 @@ class MvpDecoder(MvpPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -1160,14 +1107,16 @@ class MvpDecoder(MvpPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input, past_key_values_length) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 418f493acf..1c08a70875 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, @@ -75,39 +76,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): """ @@ -1125,7 +1093,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_router_probs = () if output_router_logits else None @@ -1342,25 +1310,16 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None and combined_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + combined_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index a7efc1a767..c62c517f9e 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -21,6 +21,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -63,38 +64,6 @@ OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class OPTLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -525,30 +494,6 @@ class OPTDecoder(OPTPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -643,7 +588,7 @@ class OPTDecoder(OPTPreTrainedModel): f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " f"{mask_seq_length} (sum of the lengths of current and past inputs)" ) - causal_attention_mask = self._prepare_decoder_attention_mask( + causal_attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) pos_embeds = self.embed_positions(attention_mask, past_key_values_length) diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 5325252980..0ee5372afb 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -14,7 +14,6 @@ # limitations under the License. """ PyTorch OWLv2 model.""" - import warnings from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union @@ -25,6 +24,7 @@ import torch.utils.checkpoint from torch import Tensor, nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -53,21 +53,6 @@ OWLV2_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlv2 def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) @@ -790,24 +775,6 @@ class Owlv2Encoder(nn.Module): ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextTransformer with OWLVIT->OWLV2,OwlViT->Owlv2 class Owlv2TextTransformer(nn.Module): def __init__(self, config: Owlv2TextConfig): @@ -845,11 +812,13 @@ class Owlv2TextTransformer(nn.Module): # num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries # OWLV2's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) # expand attention_mask if attention_mask is not None: # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 2880100d5c..77446f6512 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -14,7 +14,6 @@ # limitations under the License. """ PyTorch OWL-ViT model.""" - import warnings from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union @@ -25,6 +24,7 @@ import torch.utils.checkpoint from torch import Tensor, nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -54,21 +54,6 @@ OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) @@ -779,24 +764,6 @@ class OwlViTEncoder(nn.Module): ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - class OwlViTTextTransformer(nn.Module): def __init__(self, config: OwlViTTextConfig): super().__init__() @@ -833,11 +800,13 @@ class OwlViTTextTransformer(nn.Module): # num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries # OWLVIT's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) # expand attention_mask if attention_mask is not None: # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index f5dddbe588..35eb1ffc1b 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -72,39 +73,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus class PegasusSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" @@ -773,7 +741,7 @@ class PegasusEncoder(PegasusPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -871,30 +839,6 @@ class PegasusDecoder(PegasusPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def resize_position_embeddings(self, new_num_position_embeddings: int): """ Resizes position embeddings matrix of the model if `new_num_position_embeddings != @@ -1028,14 +972,16 @@ class PegasusDecoder(PegasusPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 60a9714fac..89c4f29cc0 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -90,39 +91,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class PegasusXSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -1138,55 +1106,6 @@ class PegasusXDecoder(PegasusXPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def resize_position_embeddings(self, new_num_position_embeddings: int): - """ - Resizes position embeddings matrix of the model if `new_num_position_embeddings != - config.max_position_embeddings`. - - Arguments: - new_num_position_embeddings (`int`): - The number of new position embeddings. If position embeddings are learned, increasing the size will add - newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If - position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will - add correct vectors at the end following the position encoding algorithm, whereas reducing the size - will remove vectors from the end. - """ - logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") - self.config.max_position_embeddings = new_num_position_embeddings - - self.embed_positions = PegasusXSinusoidalPositionalEmbedding(self.config.d_model) - self.embed_positions.to(self.device) - - def get_position_embeddings(self) -> nn.Embedding: - """ - Returns the position embeddings matrix - """ - return self.embed_positions - def forward( self, input_ids=None, @@ -1278,14 +1197,16 @@ class PegasusXDecoder(PegasusXPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(inputs_embeds, past_key_values_length) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c3a25c1030..2ffcf87489 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -27,6 +27,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings @@ -38,39 +39,6 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "PersimmonConfig" -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon class PersimmonRotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): @@ -563,30 +531,6 @@ class PersimmonModel(PersimmonPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING) def forward( self, @@ -639,7 +583,7 @@ class PersimmonModel(PersimmonPreTrainedModel): attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index d828cd8e5b..79b5c09cba 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -77,39 +78,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): return prev_output_tokens -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart class PLBartLearnedPositionalEmbedding(nn.Embedding): """ @@ -776,7 +744,7 @@ class PLBartEncoder(PLBartPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -873,29 +841,6 @@ class PLBartDecoder(PLBartPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -1002,14 +947,16 @@ class PLBartDecoder(PLBartPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input, past_key_values_length) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index a48c519166..29d8cf94e3 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -265,39 +266,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor): """ Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that @@ -1002,7 +970,7 @@ class SeamlessM4TConformerAdapterLayer(nn.Module): hidden_states.device ) attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths) - attention_mask = _expand_mask( + attention_mask = _prepare_4d_attention_mask( attention_mask, hidden_states.dtype, ) @@ -1819,7 +1787,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1934,30 +1902,6 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids: torch.LongTensor = None, @@ -2064,14 +2008,16 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input, past_key_values_length=past_key_values_length) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 24b97c028b..d3c48e6c91 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -14,7 +14,6 @@ # limitations under the License. """ PyTorch Speech2Text model.""" - import math from typing import Optional, Tuple, Union @@ -23,6 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -62,39 +62,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class Conv1dSubsampler(nn.Module): """ Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation @@ -788,7 +755,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -883,27 +850,6 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids=None, @@ -1008,14 +954,16 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index c870d7a2f7..eebc9e57cf 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, logging, replace_return_docstrings @@ -42,39 +43,6 @@ SPEECH_TO_TEXT_2_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2 class Speech2Text2SinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -492,27 +460,6 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids=None, @@ -617,14 +564,16 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) # embed positions positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index e9382b1beb..25d2b73c18 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -89,39 +90,6 @@ def shift_spectrograms_right(input_values: torch.Tensor, reduction_factor: int = return shifted_input_values -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices def _compute_mask_indices( shape: Tuple[int, int], @@ -1342,7 +1310,7 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -1539,30 +1507,6 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, hidden_states: Optional[torch.FloatTensor] = None, @@ -1647,14 +1591,16 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, hidden_states, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1] + ) deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index e44b7cdd36..dbff6476a7 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -23,6 +23,7 @@ import torch from torch import Tensor, nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -345,20 +346,6 @@ class TableTransformerConvModel(nn.Module): return out, pos -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None): - """ - Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`. - """ - batch_size, source_len = mask.size() - target_len = target_len if target_len is not None else source_len - - expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) - - # Copied from transformers.models.detr.modeling_detr.DetrSinePositionEmbedding with Detr->TableTransformer class TableTransformerSinePositionEmbedding(nn.Module): """ @@ -966,7 +953,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel): # expand attention_mask if attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1116,15 +1103,15 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel): if attention_mask is not None and combined_attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - combined_attention_mask = combined_attention_mask + _expand_mask( - attention_mask, inputs_embeds.dtype, target_len=input_shape[-1] + combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - encoder_attention_mask = _expand_mask( - encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) # optional intermediate hidden states diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 30fce5fa6d..5f961bc8e1 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -22,6 +22,7 @@ import torch from torch import nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -225,39 +226,6 @@ def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] return input_tensor.mean(dim=dim) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->TimeSeries class TimeSeriesSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" @@ -915,7 +883,7 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -999,29 +967,6 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): # Initialize weights and apply final processing self.post_init() - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, attention_mask: Optional[torch.Tensor] = None, @@ -1105,14 +1050,16 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) hidden_states = self.value_embedding(inputs_embeds) embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index d9f1f7c915..3358983a68 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, logging, replace_return_docstrings @@ -42,39 +43,6 @@ TROCR_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR class TrOCRLearnedPositionalEmbedding(nn.Embedding): """ @@ -515,27 +483,6 @@ class TrOCRDecoder(TrOCRPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids=None, @@ -655,14 +602,16 @@ class TrOCRDecoder(TrOCRPreTrainedModel): input_shape = input.shape - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) if self.gradient_checkpointing and self.training: if use_cache: diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index a58437fa5a..f347e900c3 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, ModelOutput, @@ -113,21 +114,6 @@ class VitsTextEncoderOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - @torch.jit.script def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels): in_act = input_a + input_b @@ -1150,7 +1136,7 @@ class VitsEncoder(nn.Module): # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) hidden_states = hidden_states * padding_mask diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index cd75e397fd..8df937e3e6 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation.logits_process import WhisperTimeStampLogitsProcessor +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -84,39 +85,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices def _compute_mask_indices( shape: Tuple[int, int], @@ -1003,28 +971,6 @@ class WhisperDecoder(WhisperPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids=None, @@ -1120,7 +1066,7 @@ class WhisperDecoder(WhisperPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index ce5a88d7d5..754358b352 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -24,6 +24,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -46,21 +47,6 @@ XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - # contrastive loss function, adapted from # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: @@ -729,24 +715,6 @@ class XCLIPEncoder(nn.Module): ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - class XCLIPTextTransformer(nn.Module): def __init__(self, config: XCLIPTextConfig): super().__init__() @@ -787,11 +755,13 @@ class XCLIPTextTransformer(nn.Module): # X_CLIP's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) # expand attention_mask if attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( inputs_embeds=hidden_states, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index b6c2167652..bca75fccca 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging @@ -127,39 +128,6 @@ XGLM_INPUTS_DOCSTRING = r""" """ -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - class XGLMSinusoidalPositionalEmbedding(nn.Module): """This module produces sinusoidal positional embeddings of any length.""" @@ -548,27 +516,6 @@ class XGLMModel(XGLMPreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -624,14 +571,16 @@ class XGLMModel(XGLMPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask( + attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length) hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 7a6867bba9..8e5bbfefe3 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -1567,6 +1567,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1608,37 +1609,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min) - mask_cond = torch.arange(mask.size(-1)) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -def _expand_mask( - mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None -): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) - - class {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -2280,7 +2250,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -2369,25 +2339,6 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(self.device) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - def forward( self, input_ids=None, @@ -2491,12 +2442,12 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) + attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index df41c6c5f5..2402986900 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -39,149 +39,6 @@ if is_torch_available(): LlamaModel, LlamaTokenizer, ) - from transformers.models.llama.modeling_llama import AttnMaskConverter - - -@require_torch -class AttentionMaskTester(unittest.TestCase): - def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d): - mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len)) - mask_4d_values = mask_4d[:, 0][mask_indices] - is_inf = mask_4d_values == -float("inf") - is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min - assert torch.logical_or(is_inf, is_min).all() - - def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3): - mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long) - - if additional_mask is not None: - for bsz_idx, seq_idx in additional_mask: - mask_2d[bsz_idx, seq_idx] = 0 - - mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len) - - assert mask_4d.shape == (bsz, 1, q_len, kv_len) - - context = mask_converter.sliding_window - if mask_converter.is_causal and context is None: - # k * (k+1) / 2 tokens are masked in triangualar masks - num_tokens_masked = bsz * (q_len * (q_len - 1) // 2) - - if 0 not in mask_2d: - assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked - if 0 in mask_2d: - # at least causal mask + maybe more - assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked - self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) - elif not mask_converter.is_causal and context is None: - if 0 not in mask_2d: - assert (mask_4d != 0).sum().cpu().item() == 0 - if 0 in mask_2d: - self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) - elif mask_converter.is_causal and context is not None: - # k * (k+1) / 2 tokens are masked in triangualar masks - num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len) - num_tokens_masked = bsz * num_tokens_masked - - if 0 not in mask_2d: - assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked - if 0 in mask_2d: - # at least causal mask + maybe more - assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked - self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) - - def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3): - mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device) - - if q_len == 1 and mask_converter.sliding_window is None: - # no causal mask if q_len is 1 - assert mask_4d is None - return - - context = mask_converter.sliding_window - if mask_converter.is_causal and context is None: - # k * (k+1) / 2 tokens are masked in triangualar masks - num_tokens_masked = bsz * (q_len * (q_len - 1) // 2) - - assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked - elif not mask_converter.is_causal and context is None: - assert (mask_4d != 0).sum().cpu().item() == 0 - elif mask_converter.is_causal and context is not None: - # k * (k+1) / 2 tokens are masked in triangualar masks - num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len) - num_tokens_masked = bsz * num_tokens_masked - - assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked - - def compute_num_context_mask(self, kv_len, context, q_len): - # This function computes the # of attention tokens that are added for - # the sliding window - c_mask_len = kv_len - context - num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2 - cut_mask_len = max(c_mask_len - q_len, 0) - num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2 - return num_mask_triangle - num_cut_mask - - def test_2d_to_4d_causal(self): - mask_converter = AttnMaskConverter(is_causal=True) - - # auto-regressive use case - self.check_to_4d(mask_converter, q_len=1, kv_len=7) - # special auto-regressive case - self.check_to_4d(mask_converter, q_len=3, kv_len=7) - # non auto-regressive case - self.check_to_4d(mask_converter, q_len=7, kv_len=7) - - # same with extra attention masks - self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) - self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) - self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) - - def test_2d_to_4d(self): - torch.ones((3, 7), device=torch_device, dtype=torch.long) - mask_converter = AttnMaskConverter(is_causal=False) - - # non auto-regressive case - self.check_to_4d(mask_converter, q_len=7, kv_len=7) - - # same with extra attention masks - self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) - - def test_2d_to_4d_causal_sliding(self): - torch.ones((3, 7), device=torch_device, dtype=torch.long) - mask_converter = AttnMaskConverter(is_causal=True, sliding_window=5) - - # auto-regressive use case - self.check_to_4d(mask_converter, q_len=1, kv_len=7) - # special auto-regressive case - self.check_to_4d(mask_converter, q_len=3, kv_len=7) - # non auto-regressive case - self.check_to_4d(mask_converter, q_len=7, kv_len=7) - - # same with extra attention masks - self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) - self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) - self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) - - def test_causal_mask(self): - mask_converter = AttnMaskConverter(is_causal=True) - - # auto-regressive use case - self.check_to_causal(mask_converter, q_len=1, kv_len=7) - # special auto-regressive case - self.check_to_causal(mask_converter, q_len=3, kv_len=7) - # non auto-regressive case - self.check_to_causal(mask_converter, q_len=7, kv_len=7) - - def test_causal_mask_sliding(self): - mask_converter = AttnMaskConverter(is_causal=True, sliding_window=3) - - # auto-regressive use case - self.check_to_causal(mask_converter, q_len=1, kv_len=7) - # special auto-regressive case - self.check_to_causal(mask_converter, q_len=3, kv_len=7) - # non auto-regressive case - self.check_to_causal(mask_converter, q_len=7, kv_len=7) class LlamaModelTester: diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 2a6246f870..ffdb2ae7d0 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -48,6 +48,7 @@ from transformers.testing_utils import ( require_torch_multi_gpu, require_usr_bin_time, slow, + torch_device, ) from transformers.utils import ( SAFE_WEIGHTS_INDEX_NAME, @@ -79,6 +80,7 @@ if is_torch_available(): T5Config, T5ForConditionalGeneration, ) + from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_utils import shard_checkpoint # Fake pretrained models for tests @@ -1184,3 +1186,143 @@ The commit description supports markdown synthax see: config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True) new_model = AutoModel.from_config(config, trust_remote_code=True) self.assertEqual(new_model.__class__.__name__, "CustomModel") + + +@require_torch +class AttentionMaskTester(unittest.TestCase): + def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d): + mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len)) + mask_4d_values = mask_4d[:, 0][mask_indices] + is_inf = mask_4d_values == -float("inf") + is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min + assert torch.logical_or(is_inf, is_min).all() + + def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3): + mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long) + + if additional_mask is not None: + for bsz_idx, seq_idx in additional_mask: + mask_2d[bsz_idx, seq_idx] = 0 + + mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len) + + assert mask_4d.shape == (bsz, 1, q_len, kv_len) + + context = mask_converter.sliding_window + if mask_converter.is_causal and context is None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = bsz * (q_len * (q_len - 1) // 2) + + if 0 not in mask_2d: + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + if 0 in mask_2d: + # at least causal mask + maybe more + assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked + self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) + elif not mask_converter.is_causal and context is None: + if 0 not in mask_2d: + assert (mask_4d != 0).sum().cpu().item() == 0 + if 0 in mask_2d: + self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) + elif mask_converter.is_causal and context is not None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len) + num_tokens_masked = bsz * num_tokens_masked + + if 0 not in mask_2d: + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + if 0 in mask_2d: + # at least causal mask + maybe more + assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked + self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) + + def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3): + mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device) + + if q_len == 1 and mask_converter.sliding_window is None: + # no causal mask if q_len is 1 + assert mask_4d is None + return + + context = mask_converter.sliding_window + if mask_converter.is_causal and context is None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = bsz * (q_len * (q_len - 1) // 2) + + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + elif not mask_converter.is_causal and context is None: + assert (mask_4d != 0).sum().cpu().item() == 0 + elif mask_converter.is_causal and context is not None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len) + num_tokens_masked = bsz * num_tokens_masked + + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + + def compute_num_context_mask(self, kv_len, context, q_len): + # This function computes the # of attention tokens that are added for + # the sliding window + c_mask_len = kv_len - context + num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2 + cut_mask_len = max(c_mask_len - q_len, 0) + num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2 + return num_mask_triangle - num_cut_mask + + def test_2d_to_4d_causal(self): + mask_converter = AttentionMaskConverter(is_causal=True) + + # auto-regressive use case + self.check_to_4d(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_4d(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_4d(mask_converter, q_len=7, kv_len=7) + + # same with extra attention masks + self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + + def test_2d_to_4d(self): + mask_converter = AttentionMaskConverter(is_causal=False) + + # non auto-regressive case + self.check_to_4d(mask_converter, q_len=7, kv_len=7) + + # same with extra attention masks + self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + + def test_2d_to_4d_causal_sliding(self): + mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=5) + + # auto-regressive use case + self.check_to_4d(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_4d(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_4d(mask_converter, q_len=7, kv_len=7) + + # same with extra attention masks + self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + + def test_causal_mask(self): + mask_converter = AttentionMaskConverter(is_causal=True) + + # auto-regressive use case + self.check_to_causal(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_causal(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_causal(mask_converter, q_len=7, kv_len=7) + + def test_causal_mask_sliding(self): + mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=3) + + # auto-regressive use case + self.check_to_causal(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_causal(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_causal(mask_converter, q_len=7, kv_len=7) diff --git a/utils/check_inits.py b/utils/check_inits.py index 43361adbf8..383aa9c5db 100644 --- a/utils/check_inits.py +++ b/utils/check_inits.py @@ -329,6 +329,7 @@ IGNORE_SUBMODULES = [ "convert_pytorch_checkpoint_to_tf2", "modeling_flax_pytorch_utils", "models.esm.openfold_utils", + "modeling_attn_mask_utils", ]