[cleanup] factor out get_head_mask, invert_attn_mask, get_exten… (#3806)
* Delete some copy pasted code
This commit is contained in:
@@ -17,10 +17,10 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import typing
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import Tensor, device, dtype, nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
@@ -109,9 +109,102 @@ class ModuleUtilsMixin:
|
||||
module.mem_rss_pre_forward = 0
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
def device(self) -> device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self) -> dtype:
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
|
||||
"""type: torch.Tensor -> torch.Tensor"""
|
||||
if encoder_attention_mask.dim() == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
if encoder_attention_mask.dim() == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
||||
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
|
||||
# /transformer/transformer_layers.py#L270
|
||||
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
|
||||
# encoder_extended_attention_mask.transpose(-1, -2))
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
|
||||
return encoder_extended_attention_mask
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
|
||||
"""Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
|
||||
|
||||
Arguments:
|
||||
attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
|
||||
input_shape: tuple, shape of input_ids
|
||||
device: torch.Device, usually self.device
|
||||
|
||||
Returns:
|
||||
torch.Tensor with dtype of attention_mask.dtype
|
||||
"""
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
if attention_mask.dim() == 3:
|
||||
extended_attention_mask = attention_mask[:, None, :, :]
|
||||
elif attention_mask.dim() == 2:
|
||||
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder:
|
||||
batch_size, seq_length = input_shape
|
||||
seq_ids = torch.arange(seq_length, device=device)
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||
# causal and attention masks must have same type with pytorch version < 1.3
|
||||
causal_mask = causal_mask.to(attention_mask.dtype)
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
else:
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||
input_shape, attention_mask.shape
|
||||
)
|
||||
)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
return extended_attention_mask
|
||||
|
||||
def get_head_mask(self, head_mask, num_hidden_layers):
|
||||
"""
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
attention_probs has shape bsz x n_heads x N x N
|
||||
Arguments:
|
||||
head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
num_hidden_layers: int
|
||||
Returns:
|
||||
Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
or list with [None] for each layer
|
||||
"""
|
||||
if head_mask is not None:
|
||||
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
|
||||
else:
|
||||
head_mask = [None] * num_hidden_layers
|
||||
|
||||
return head_mask
|
||||
|
||||
def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
|
||||
"""-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
|
||||
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
|
||||
return head_mask
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
r""" Base class for all models.
|
||||
@@ -340,7 +433,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||
|
||||
if hasattr(self.config, "xla_device") and self.config.xla_device:
|
||||
if getattr(self.config, "xla_device", False):
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
if xm.is_master_ordinal():
|
||||
@@ -588,13 +681,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# Make sure we are able to load base models as well as derived models (with heads)
|
||||
start_prefix = ""
|
||||
model_to_load = model
|
||||
if not hasattr(model, cls.base_model_prefix) and any(
|
||||
s.startswith(cls.base_model_prefix) for s in state_dict.keys()
|
||||
):
|
||||
has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
|
||||
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
|
||||
start_prefix = cls.base_model_prefix + "."
|
||||
if hasattr(model, cls.base_model_prefix) and not any(
|
||||
s.startswith(cls.base_model_prefix) for s in state_dict.keys()
|
||||
):
|
||||
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
||||
model_to_load = getattr(model, cls.base_model_prefix)
|
||||
|
||||
load(model_to_load, prefix=start_prefix)
|
||||
@@ -627,7 +717,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
)
|
||||
model.tie_weights() # make sure token embedding weights are still tied if needed
|
||||
|
||||
# Set model in evaluation mode to desactivate DropOut modules by default
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
if output_loading_info:
|
||||
@@ -944,7 +1034,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# get encoder and store encoder outputs
|
||||
encoder = self.get_encoder()
|
||||
|
||||
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
|
||||
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
|
||||
|
||||
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
||||
if num_return_sequences > 1 or num_beams > 1:
|
||||
@@ -1446,12 +1536,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
scores[:, all_but_token_ids_mask] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
|
||||
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
|
||||
|
||||
|
||||
def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
|
||||
# Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
|
||||
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||
if cur_len + 1 < no_repeat_ngram_size:
|
||||
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||
return [[] for _ in range(num_hypos)]
|
||||
@@ -1883,9 +1973,7 @@ class SequenceSummary(nn.Module):
|
||||
self.summary = nn.Linear(config.hidden_size, num_classes)
|
||||
|
||||
activation_string = getattr(config, "summary_activation", None)
|
||||
self.activation = (
|
||||
get_activation(activation_string) if activation_string else Identity()
|
||||
) # type: typing.Callable
|
||||
self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
|
||||
|
||||
self.first_dropout = Identity()
|
||||
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
|
||||
|
||||
Reference in New Issue
Block a user