[cleanup] factor out get_head_mask, invert_attn_mask, get_exten… (#3806)

* Delete some copy pasted code
This commit is contained in:
Sam Shleifer
2020-04-16 09:55:25 -04:00
committed by GitHub
parent d22894dfd4
commit dbd041243d
13 changed files with 132 additions and 379 deletions

View File

@@ -17,10 +17,10 @@
import logging
import os
import typing
from typing import Callable, Tuple
import torch
from torch import nn
from torch import Tensor, device, dtype, nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
@@ -109,9 +109,102 @@ class ModuleUtilsMixin:
module.mem_rss_pre_forward = 0
@property
def device(self):
def device(self) -> device:
return next(self.parameters()).device
@property
def dtype(self) -> dtype:
return next(self.parameters()).dtype
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
"""type: torch.Tensor -> torch.Tensor"""
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
# /transformer/transformer_layers.py#L270
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
return encoder_extended_attention_mask
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
"""Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
Arguments:
attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
input_shape: tuple, shape of input_ids
device: torch.Device, usually self.device
Returns:
torch.Tensor with dtype of attention_mask.dtype
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def get_head_mask(self, head_mask, num_hidden_layers):
"""
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
attention_probs has shape bsz x n_heads x N x N
Arguments:
head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
num_hidden_layers: int
Returns:
Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
or list with [None] for each layer
"""
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
else:
head_mask = [None] * num_hidden_layers
return head_mask
def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
"""-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
return head_mask
class PreTrainedModel(nn.Module, ModuleUtilsMixin):
r""" Base class for all models.
@@ -340,7 +433,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
if hasattr(self.config, "xla_device") and self.config.xla_device:
if getattr(self.config, "xla_device", False):
import torch_xla.core.xla_model as xm
if xm.is_master_ordinal():
@@ -588,13 +681,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
if not hasattr(model, cls.base_model_prefix) and any(
s.startswith(cls.base_model_prefix) for s in state_dict.keys()
):
has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
start_prefix = cls.base_model_prefix + "."
if hasattr(model, cls.base_model_prefix) and not any(
s.startswith(cls.base_model_prefix) for s in state_dict.keys()
):
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix)
load(model_to_load, prefix=start_prefix)
@@ -627,7 +717,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
)
model.tie_weights() # make sure token embedding weights are still tied if needed
# Set model in evaluation mode to desactivate DropOut modules by default
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
if output_loading_info:
@@ -944,7 +1034,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
@@ -1446,12 +1536,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
scores[:, all_but_token_ids_mask] = -float("inf")
@staticmethod
def _reorder_cache(past, beam_idx):
def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
# Copied from fairseq for no_repeat_ngram in beam_search"""
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < no_repeat_ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]
@@ -1883,9 +1973,7 @@ class SequenceSummary(nn.Module):
self.summary = nn.Linear(config.hidden_size, num_classes)
activation_string = getattr(config, "summary_activation", None)
self.activation = (
get_activation(activation_string) if activation_string else Identity()
) # type: typing.Callable
self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
self.first_dropout = Identity()
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: