Remove device parameter from create_extended_attention_mask_for_decoder (#16894)
This commit is contained in:
@@ -137,7 +137,7 @@ class RetrievalQAEmbedder(nn.Module):
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
|
||||
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
|
||||
attention_mask, input_shape, device
|
||||
attention_mask, input_shape
|
||||
)
|
||||
|
||||
# define function for checkpointing
|
||||
|
||||
@@ -651,7 +651,13 @@ class ModuleUtilsMixin:
|
||||
return encoder_extended_attention_mask
|
||||
|
||||
@staticmethod
|
||||
def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
|
||||
def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
|
||||
if device is not None:
|
||||
warnings.warn(
|
||||
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
||||
)
|
||||
else:
|
||||
device = attention_mask.device
|
||||
batch_size, seq_length = input_shape
|
||||
seq_ids = torch.arange(seq_length, device=device)
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||
@@ -672,7 +678,9 @@ class ModuleUtilsMixin:
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
return extended_attention_mask
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
|
||||
def get_extended_attention_mask(
|
||||
self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
|
||||
) -> Tensor:
|
||||
"""
|
||||
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||
|
||||
@@ -681,12 +689,16 @@ class ModuleUtilsMixin:
|
||||
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||
input_shape (`Tuple[int]`):
|
||||
The shape of the input to the model.
|
||||
device: (`torch.device`):
|
||||
The device of the input to the model.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
|
||||
"""
|
||||
if not (attention_mask.dim() == 2 and self.config.is_decoder):
|
||||
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
|
||||
if device is not None:
|
||||
warnings.warn(
|
||||
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
||||
)
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
if attention_mask.dim() == 3:
|
||||
|
||||
@@ -982,7 +982,7 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -364,9 +364,7 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask = None
|
||||
if not use_cache:
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
attention_mask, input_shape, device
|
||||
)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -2112,9 +2112,7 @@ class BigBirdModel(BigBirdPreTrainedModel):
|
||||
to_mask = None
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
attention_mask, input_shape, device
|
||||
)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"attention_type can either be original_full or block_sparse, but is {self.attention_type}"
|
||||
|
||||
@@ -1130,12 +1130,12 @@ class CanineModel(CaninePreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
molecule_attention_mask = self._downsample_attention_mask(
|
||||
attention_mask, downsampling_rate=self.config.downsampling_rate
|
||||
)
|
||||
extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]), device
|
||||
molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
|
||||
@@ -833,7 +833,7 @@ class ConvBertModel(ConvBertPreTrainedModel):
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
|
||||
@@ -820,7 +820,7 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -882,7 +882,7 @@ class ElectraModel(ElectraPreTrainedModel):
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -814,7 +814,7 @@ class IBertModel(IBertPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
|
||||
@@ -1692,7 +1692,7 @@ class LongformerModel(LongformerPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)[
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)[
|
||||
:, 0, 0, :
|
||||
]
|
||||
|
||||
|
||||
@@ -940,7 +940,7 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -268,7 +268,7 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
|
||||
[torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
|
||||
)
|
||||
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
|
||||
@@ -875,9 +875,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
attention_mask, input_shape, self.device
|
||||
)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
|
||||
@@ -547,7 +547,7 @@ class MPNetModel(MPNetPreTrainedModel):
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
|
||||
|
||||
@@ -624,7 +624,7 @@ class NystromformerModel(NystromformerPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
|
||||
@@ -952,7 +952,7 @@ class QDQBertModel(QDQBertPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -1078,7 +1078,7 @@ class RealmBertModel(RealmPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -857,7 +857,7 @@ class RemBertModel(RemBertPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -117,7 +117,7 @@ class RetriBertModel(RetriBertPreTrainedModel):
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
head_mask = [None] * sent_encoder.config.num_hidden_layers
|
||||
extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
|
||||
attention_mask, input_shape, device
|
||||
attention_mask, input_shape
|
||||
)
|
||||
|
||||
# define function for checkpointing
|
||||
|
||||
@@ -817,7 +817,7 @@ class RobertaModel(RobertaPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -900,7 +900,7 @@ class RoFormerModel(RoFormerPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -710,7 +710,7 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -612,7 +612,7 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
|
||||
@@ -957,7 +957,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -954,7 +954,7 @@ class TapasModel(TapasPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -843,7 +843,7 @@ class ViltModel(ViltPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
|
||||
@@ -794,12 +794,12 @@ class VisualBertModel(VisualBertPreTrainedModel):
|
||||
if visual_embeds is not None:
|
||||
combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
|
||||
combined_attention_mask, (batch_size, input_shape + visual_input_shape)
|
||||
)
|
||||
|
||||
else:
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
attention_mask, [batch_size, input_shape], device
|
||||
attention_mask, (batch_size, input_shape)
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
|
||||
@@ -788,7 +788,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -816,7 +816,7 @@ class YosoModel(YosoPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
|
||||
@@ -876,7 +876,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
Reference in New Issue
Block a user