Remove device parameter from create_extended_attention_mask_for_decoder (#16894)

This commit is contained in:
Pavel Belevich
2022-05-03 11:06:11 -04:00
committed by GitHub
parent dd739f7045
commit 39f8eafc1b
31 changed files with 48 additions and 42 deletions

View File

@@ -137,7 +137,7 @@ class RetrievalQAEmbedder(nn.Module):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
attention_mask, input_shape, device
attention_mask, input_shape
)
# define function for checkpointing

View File

@@ -651,7 +651,13 @@ class ModuleUtilsMixin:
return encoder_extended_attention_mask
@staticmethod
def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
if device is not None:
warnings.warn(
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
else:
device = attention_mask.device
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
@@ -672,7 +678,9 @@ class ModuleUtilsMixin:
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
return extended_attention_mask
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
def get_extended_attention_mask(
self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -681,12 +689,16 @@ class ModuleUtilsMixin:
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
device: (`torch.device`):
The device of the input to the model.
Returns:
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
"""
if not (attention_mask.dim() == 2 and self.config.is_decoder):
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
if device is not None:
warnings.warn(
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:

View File

@@ -982,7 +982,7 @@ class BertModel(BertPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -364,9 +364,7 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = None
if not use_cache:
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, device
)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -2112,9 +2112,7 @@ class BigBirdModel(BigBirdPreTrainedModel):
to_mask = None
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, device
)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
else:
raise ValueError(
f"attention_type can either be original_full or block_sparse, but is {self.attention_type}"

View File

@@ -1130,12 +1130,12 @@ class CanineModel(CaninePreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
molecule_attention_mask = self._downsample_attention_mask(
attention_mask, downsampling_rate=self.config.downsampling_rate
)
extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask(
molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]), device
molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
)
# Prepare head mask if needed

View File

@@ -833,7 +833,7 @@ class ConvBertModel(ConvBertPreTrainedModel):
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = self.embeddings(

View File

@@ -820,7 +820,7 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -882,7 +882,7 @@ class ElectraModel(ElectraPreTrainedModel):
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -814,7 +814,7 @@ class IBertModel(IBertPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head

View File

@@ -1692,7 +1692,7 @@ class LongformerModel(LongformerPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)[
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)[
:, 0, 0, :
]

View File

@@ -940,7 +940,7 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -268,7 +268,7 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
[torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

View File

@@ -875,9 +875,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, self.device
)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head

View File

@@ -547,7 +547,7 @@ class MPNetModel(MPNetPreTrainedModel):
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)

View File

@@ -624,7 +624,7 @@ class NystromformerModel(NystromformerPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head

View File

@@ -952,7 +952,7 @@ class QDQBertModel(QDQBertPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -1078,7 +1078,7 @@ class RealmBertModel(RealmPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -857,7 +857,7 @@ class RemBertModel(RemBertPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -117,7 +117,7 @@ class RetriBertModel(RetriBertPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * sent_encoder.config.num_hidden_layers
extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
attention_mask, input_shape, device
attention_mask, input_shape
)
# define function for checkpointing

View File

@@ -817,7 +817,7 @@ class RobertaModel(RobertaPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -900,7 +900,7 @@ class RoFormerModel(RoFormerPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -710,7 +710,7 @@ class SplinterModel(SplinterPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -612,7 +612,7 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N

View File

@@ -957,7 +957,7 @@ class T5Stack(T5PreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -954,7 +954,7 @@ class TapasModel(TapasPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -843,7 +843,7 @@ class ViltModel(ViltPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
encoder_outputs = self.encoder(
embedding_output,

View File

@@ -794,12 +794,12 @@ class VisualBertModel(VisualBertPreTrainedModel):
if visual_embeds is not None:
combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
combined_attention_mask, [batch_size, input_shape + visual_input_shape], device
combined_attention_mask, (batch_size, input_shape + visual_input_shape)
)
else:
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, [batch_size, input_shape], device
attention_mask, (batch_size, input_shape)
)
# Prepare head mask if needed

View File

@@ -788,7 +788,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -816,7 +816,7 @@ class YosoModel(YosoPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head

View File

@@ -876,7 +876,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]