946 lines
40 KiB
Python
946 lines
40 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" PyTorch Transformer XL model.
|
|
Adapted from https://github.com/kimiyoung/transformer-xl.
|
|
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
|
|
"""
|
|
|
|
|
|
import logging
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .configuration_transfo_xl import TransfoXLConfig
|
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
|
from .modeling_transfo_xl_utilities import LogUniformSampler, ProjectedAdaptiveLogSoftmax, sample_logits
|
|
from .modeling_utils import PreTrainedModel
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Maps a pretrained model name to the URL of its PyTorch weights file.
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}
|
|
|
|
|
|
def build_tf_to_pytorch_map(model, config):
    """Build the map from TF checkpoint variable names to PyTorch parameters.

    A map is used so that the PyTorch model can stay as identical as possible
    to the original PyTorch implementation.

    Args:
        model: Either the bare transformer, or a wrapper that exposes it as
            ``model.transformer`` together with an adaptive-softmax criterion
            ``model.crit``.
        config: Configuration object. ``untie_r`` is always read; ``tie_projs``
            and ``tie_weight`` are read only when ``model`` has a
            ``transformer`` attribute.

    Returns:
        dict: TF variable name -> PyTorch parameter. The two relative-position
        bias entries (``transformer/r_r_bias`` and ``transformer/r_w_bias``)
        map to a *list* of parameters: one per layer when ``config.untie_r``
        is set, otherwise a single-element list holding the shared bias.

    Raises:
        NotImplementedError: If the model has adaptive-softmax output layers
            and ``config.tie_weight`` is false.
    """
    tf_to_pt_map = {}

    if hasattr(model, "transformer"):
        # We are loading a model with an LM head => also map the adaptive softmax.
        crit = model.crit
        tf_to_pt_map["transformer/adaptive_softmax/cutoff_0/cluster_W"] = crit.cluster_weight
        tf_to_pt_map["transformer/adaptive_softmax/cutoff_0/cluster_b"] = crit.cluster_bias

        clusters = zip(crit.out_layers, crit.out_projs, config.tie_projs)
        for i, (out_l, proj_l, tie_proj) in enumerate(clusters):
            layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
            if not config.tie_weight:
                # Untied output embeddings are not supported (the original author
                # noted they did not think the TF code implements them either).
                raise NotImplementedError
            tf_to_pt_map[layer_str + "b"] = out_l.bias
            if not tie_proj:
                tf_to_pt_map[layer_str + "proj"] = proj_l

        # Now map the rest of the transformer.
        model = model.transformer

    # Embeddings
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
        tf_to_pt_map[layer_str + "lookup_table"] = embed_l.weight
        tf_to_pt_map[layer_str + "proj_W"] = proj_l

    # Transformer blocks
    for i, b in enumerate(model.layers):
        layer_str = "transformer/layer_%d/" % i
        attn, ff = b.dec_attn, b.pos_ff
        tf_to_pt_map.update(
            {
                layer_str + "rel_attn/LayerNorm/gamma": attn.layer_norm.weight,
                layer_str + "rel_attn/LayerNorm/beta": attn.layer_norm.bias,
                layer_str + "rel_attn/o/kernel": attn.o_net.weight,
                layer_str + "rel_attn/qkv/kernel": attn.qkv_net.weight,
                layer_str + "rel_attn/r/kernel": attn.r_net.weight,
                layer_str + "ff/LayerNorm/gamma": ff.layer_norm.weight,
                layer_str + "ff/LayerNorm/beta": ff.layer_norm.bias,
                # CoreNet[0] and CoreNet[3] are the two linear layers of the feed-forward stack.
                layer_str + "ff/layer_1/kernel": ff.CoreNet[0].weight,
                layer_str + "ff/layer_1/bias": ff.CoreNet[0].bias,
                layer_str + "ff/layer_2/kernel": ff.CoreNet[3].weight,
                layer_str + "ff/layer_2/bias": ff.CoreNet[3].bias,
            }
        )

    # Relative positioning biases
    if config.untie_r:
        # One pair of biases per layer.
        r_r_list = [b.dec_attn.r_r_bias for b in model.layers]
        r_w_list = [b.dec_attn.r_w_bias for b in model.layers]
    else:
        # A single pair of biases stored on the model itself.
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map["transformer/r_r_bias"] = r_r_list
    tf_to_pt_map["transformer/r_w_bias"] = r_w_list
    return tf_to_pt_map
|
|
|
|
|
|
def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """Load the weights of a TensorFlow checkpoint into a PyTorch model.

    Args:
        model: PyTorch model whose parameters are overwritten in place.
        config: Model configuration, forwarded to ``build_tf_to_pytorch_map``.
        tf_path: Path of the TensorFlow checkpoint to read.

    Returns:
        The same ``model`` object, with its mapped parameters replaced.

    Raises:
        ImportError: If NumPy or TensorFlow is not installed.
        AssertionError: If a mapped variable is missing from the checkpoint or
            a checkpoint array does not match the shape of its target parameter.
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load every variable of the TF checkpoint
    tf_weights = {}
    for name, shape in tf.train.list_variables(tf_path):
        logger.info("Loading TF weight %s with shape %s", name, shape)
        tf_weights[name] = tf.train.load_variable(tf_path, name)

    for name, pointer in tf_to_pt_map.items():
        # Explicit raise (not `assert`) so the check is not stripped under `python -O`.
        if name not in tf_weights:
            raise AssertionError("TF variable {} not found in checkpoint".format(name))
        array = tf_weights[name]

        # Kernels and projections are transposed before being copied.
        if "kernel" in name or "proj" in name:
            array = np.transpose(array)

        # Collect (parameter, values, layer index) triples to copy for this variable.
        if "r_r_bias" in name or "r_w_bias" in name:
            # For these two entries `pointer` is always a list of parameters.
            if len(pointer) == 1 and array.ndim == pointer[0].dim():
                # Single shared bias: the TF array maps onto the one parameter directly.
                targets = [(pointer[0], array, None)]
            else:
                # One bias per layer: split the TF array along its first axis.
                if len(pointer) != array.shape[0]:
                    raise AssertionError(len(pointer), array.shape)
                targets = [(p_i, array[i, ...], i) for i, p_i in enumerate(pointer)]
        else:
            targets = [(pointer, array, None)]

        for param, values, layer_idx in targets:
            if param.shape != values.shape:
                raise AssertionError(param.shape, values.shape)
            if layer_idx is None:
                logger.info("Initialize PyTorch weight %s", name)
            else:
                logger.info("Initialize PyTorch weight %s for layer %s", name, layer_idx)
            param.data = torch.from_numpy(values)

        # Drop the variable and its Adam optimizer slots so that the final log
        # only lists what was really left uncopied; the optimizer slots are not
        # required for using the pretrained model.
        tf_weights.pop(name, None)
        tf_weights.pop(name + "/Adam", None)
        tf_weights.pop(name + "/Adam_1", None)

    logger.info("Weights not copied to PyTorch model: %s", ", ".join(tf_weights))
    return model
|
|
|
|
|
|
class PositionalEmbedding(nn.Module):
    """Sinusoidal embedding of a 1-D sequence of positions.

    The output concatenates the sines and the cosines of ``pos * inv_freq`` along
    the last dimension.
    """

    def __init__(self, demb):
        super().__init__()
        self.demb = demb

        # Inverse frequencies: 1 / 10000^(k / demb) for k = 0, 2, 4, ...
        exponents = torch.arange(0.0, demb, 2.0) / demb
        self.register_buffer("inv_freq", 1 / (10000 ** exponents))

    def forward(self, pos_seq, bsz=None):
        """Embed ``pos_seq``.

        Returns a tensor of shape ``(len(pos_seq), 1, 2 * len(inv_freq))``, or
        with the middle dimension expanded to ``bsz`` when ``bsz`` is given.
        """
        angles = torch.ger(pos_seq, self.inv_freq)  # outer product: one row per position
        table = torch.cat([angles.sin(), angles.cos()], dim=-1).unsqueeze(1)

        if bsz is None:
            return table
        return table.expand(-1, bsz, -1)
|
|
|
|
|
|
class PositionwiseFF(nn.Module):
    """Position-wise feed-forward sub-layer with a residual connection and layer norm.

    With ``pre_lnorm`` the layer norm is applied to the input of the feed-forward
    stack; otherwise it is applied to the sum of the input and the stack output.
    """

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5):
        super().__init__()

        self.d_model, self.d_inner, self.dropout = d_model, d_inner, dropout

        # Indices 0 and 3 hold the two linear layers.
        stages = [
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        ]
        self.CoreNet = nn.Sequential(*stages)

        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        """Apply the sub-layer to ``inp`` and return a tensor of the same shape."""
        if not self.pre_lnorm:
            # feed-forward, then residual connection + layer normalization
            transformed = self.CoreNet(inp)
            return self.layer_norm(inp + transformed)

        # layer normalization + feed-forward, then residual connection
        transformed = self.CoreNet(self.layer_norm(inp))
        return transformed + inp
|
|
|
|
|
|
class RelPartialLearnableMultiHeadAttn(nn.Module):
    """Multi-head attention with relative positional scores and learnable biases.

    The attention score is the sum of a content term (queries against keys) and a
    position term (queries against projected positional embeddings), each using
    its own learnable bias (``r_w_bias`` and ``r_r_bias``). The biases are either
    owned by this module or supplied by the caller so that layers can share them.
    """

    def __init__(
        self,
        n_head,
        d_model,
        d_head,
        dropout,
        dropatt=0,
        tgt_len=None,
        ext_len=None,
        mem_len=None,
        pre_lnorm=False,
        r_r_bias=None,
        r_w_bias=None,
        output_attentions=False,
        layer_norm_epsilon=1e-5,
    ):
        super().__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        # Single projection producing queries, keys and values at once.
        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is not None and r_w_bias is not None:
            # Biases supplied by the caller.
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            # Biases are not shared: this module owns its own pair. The storage
            # is allocated but not initialized here.
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

        # Projection of the positional embeddings into per-head keys.
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def _rel_shift(self, x):
        """Shift the position scores by padding a zero column and re-viewing the buffer."""
        n_rows, n_cols = x.size(0), x.size(1)
        trailing = x.size()[2:]

        zero_col = torch.zeros((n_rows, 1) + trailing, device=x.device, dtype=x.dtype)
        padded = torch.cat([zero_col, x], dim=1)

        # Reinterpret the padded buffer with its first two dimensions swapped.
        padded = padded.view(n_cols + 1, n_rows, *trailing)

        return padded[1:].view_as(x)

    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        """Attend from ``w`` over ``mems`` (if any) followed by ``w``.

        Returns a list: the sub-layer output, followed by the attention
        probabilities when ``output_attentions`` is set.
        """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # Keys and values cover the memory followed by the current segment.
        context = w if mems is None else torch.cat([mems, w], 0)
        if self.pre_lnorm:
            context = self.layer_norm(context)
        w_heads = self.qkv_net(context)
        r_head_k = self.r_net(r)

        w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        if mems is not None:
            # Queries are only needed for the current segment.
            w_head_q = w_head_q[-qlen:]

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # rlen x n_head x d_head

        # content-based score
        rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
        AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head

        # position-based score
        rr_head_q = w_head_q + self.r_r_bias
        BD = torch.einsum("ibnd,jnd->ijbn", (rr_head_q, r_head_k))  # qlen x rlen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        # Mask out the positions flagged with 1 (skipped when the mask is all zeros).
        if attn_mask is not None and torch.sum(attn_mask).item():
            bool_mask = attn_mask == 1  # Switch to bool
            if bool_mask.dim() == 2:
                mask4d = bool_mask[None, :, :, None]
            elif bool_mask.dim() == 3:
                mask4d = bool_mask[:, :, :, None]
            else:
                mask4d = None  # other ranks are ignored

            if mask4d is not None:
                # fp16 cannot represent -1e30, so a smaller magnitude is used there.
                half_precision = next(self.parameters()).dtype == torch.float16
                fill_value = -65000 if half_precision else -1e30
                attn_score = attn_score.float().masked_fill(mask4d, fill_value).type_as(attn_score)

        # [qlen x klen x bsz x n_head], normalized over the key dimension
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # weighted sum of the values
        attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head] -> [qlen x bsz x (n_head * d_head)]
        attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        # output projection
        attn_out = self.drop(self.o_net(attn_vec))

        # residual connection, followed by layer normalization unless it was applied up front
        residual = w + attn_out
        outputs = [residual if self.pre_lnorm else self.layer_norm(residual)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
|
|
|
|
|
|
class RelPartialLearnableDecoderLayer(nn.Module):
    """One decoder layer: relative multi-head attention followed by a position-wise feed-forward."""

    def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5, **kwargs):
        super().__init__()

        # Every extra keyword argument is forwarded to the attention module.
        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, layer_norm_epsilon=layer_norm_epsilon, **kwargs
        )
        # The feed-forward uses the same `pre_lnorm` setting (None when not given).
        pre_lnorm = kwargs.get("pre_lnorm")
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout, pre_lnorm=pre_lnorm, layer_norm_epsilon=layer_norm_epsilon
        )

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
        """Run the layer; returns ``[output]`` followed by any extra attention outputs."""
        attended, *extras = self.dec_attn(dec_inp, r, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask)
        return [self.pos_ff(attended)] + extras
|
|
|
|
|
|
class AdaptiveEmbedding(nn.Module):
    """Token embedding whose vocabulary may be split into clusters of decreasing width.

    With ``div_val == 1`` a single embedding table of width ``d_embed`` is used,
    projected to ``d_proj`` only when the two sizes differ. Otherwise the
    vocabulary is split at ``cutoffs``; cluster ``i`` gets an embedding table of
    width ``d_embed // div_val ** i`` and its own projection to ``d_proj``.
    The result is scaled by ``sqrt(d_proj)``.

    Args:
        n_token: Vocabulary size.
        d_embed: Embedding width of the first (or only) cluster.
        d_proj: Width of the returned embeddings.
        cutoffs: Sequence of cluster boundaries (token ids); ``n_token`` is
            appended as the final boundary.
        div_val: Divisor applied to the embedding width from one cluster to the next.
        sample_softmax: When greater than zero, the single table used for
            ``div_val == 1`` is created with ``sparse=True``.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        # list(...) lets callers pass any sequence (e.g. a tuple), not only a list.
        self.cutoffs = list(cutoffs) + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
            if d_proj != d_embed:
                # Projection storage is allocated but not initialized here.
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
        else:
            for i, (l_idx, r_idx) in enumerate(zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])):
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

    def forward(self, inp):
        """Embed the token ids ``inp``; returns a tensor of shape ``inp.shape + (d_proj,)``."""
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            # reshape (rather than view) also accepts non-contiguous inputs.
            inp_flat = inp.reshape(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
            for i, (l_idx, r_idx) in enumerate(zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])):
                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # squeeze(1) keeps the index 1-D even when exactly one token matches;
                # a bare squeeze() would collapse it to a 0-dim tensor.
                indices_i = mask_i.nonzero().squeeze(1)

                if indices_i.numel() == 0:
                    continue

                # Token ids are re-based to the start of their cluster.
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed
|
|
|
|
|
|
class TransfoXLPreTrainedModel(PreTrainedModel):
    """Abstract base class that handles weight initialization and provides
    a simple interface for downloading and loading pretrained models.
    """

    config_class = TransfoXLConfig
    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

    def _init_weight(self, weight):
        """Initialize ``weight`` in place following ``config.init`` ("uniform" or "normal")."""
        scheme = self.config.init
        if scheme == "uniform":
            bound = self.config.init_range
            nn.init.uniform_(weight, -bound, bound)
        elif scheme == "normal":
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def _init_bias(self, bias):
        """Set ``bias`` to zero in place."""
        nn.init.constant_(bias, 0.0)

    def _init_weights(self, m):
        """Initialize the weights of module ``m``, dispatching on its class name."""
        classname = type(m).__name__

        if "Linear" in classname:
            if getattr(m, "weight", None) is not None:
                self._init_weight(m.weight)
            if getattr(m, "bias", None) is not None:
                self._init_bias(m.bias)
            return

        if "AdaptiveEmbedding" in classname:
            # Only the projections are handled here; the embedding tables are
            # modules of their own and fall into the "Embedding" case.
            if hasattr(m, "emb_projs"):
                for proj in m.emb_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
            return

        if "Embedding" in classname:
            if hasattr(m, "weight"):
                self._init_weight(m.weight)
            return

        if "ProjectedAdaptiveLogSoftmax" in classname:
            if getattr(m, "cluster_weight", None) is not None:
                self._init_weight(m.cluster_weight)
            if getattr(m, "cluster_bias", None) is not None:
                self._init_bias(m.cluster_bias)
            if hasattr(m, "out_projs"):
                for proj in m.out_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
            return

        if "LayerNorm" in classname:
            if hasattr(m, "weight"):
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if getattr(m, "bias", None) is not None:
                self._init_bias(m.bias)
            return

        # Any other module: initialize the relative-position parameters it carries.
        for attr in ("r_emb", "r_w_bias", "r_r_bias"):
            if hasattr(m, attr):
                self._init_weight(getattr(m, attr))
        if hasattr(m, "r_bias"):
            self._init_bias(m.r_bias)
|
|
|
|
|
|
TRANSFO_XL_START_DOCSTRING = r"""
|
|
|
|
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
|
usage and behavior.
|
|
|
|
Parameters:
|
|
config (:class:`~transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
|
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
|
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
|
"""
|
|
|
|
# Shared description of the ``forward`` arguments, passed to
# ``add_start_docstrings_to_callable`` on the ``forward`` methods in this file.
TRANSFO_XL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.TransfoXLTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layer`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
            given to this model should not be passed as input ids as they have already been computed.
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
"""
|
|
|
|
|
|
@add_start_docstrings(
    "The bare Transformer-XL Model transformer outputting raw hidden-states without any specific head on top.",
    TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLModel(TransfoXLPreTrainedModel):
    # Stack of relative-attention decoder layers on top of an adaptive embedding.
    # Hidden states of previous segments ("mems") are cached and fed back as
    # extra attention context for the next call.

    def __init__(self, config):
        super().__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.n_token = config.vocab_size

        self.d_embed = config.d_embed
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head

        self.word_emb = AdaptiveEmbedding(
            config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
        )

        self.drop = nn.Dropout(config.dropout)

        self.n_layer = config.n_layer

        self.tgt_len = config.tgt_len
        self.mem_len = config.mem_len
        self.ext_len = config.ext_len
        self.max_klen = config.tgt_len + config.ext_len + config.mem_len

        self.attn_type = config.attn_type

        if not config.untie_r:
            # Relative-position biases shared by every layer (allocated here, not initialized).
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

        self.layers = nn.ModuleList()
        if config.attn_type != 0:
            # Learnable embeddings and absolute embeddings are not used in the pretrained
            # checkpoints; they were removed to avoid maintaining dead code.
            raise NotImplementedError

        # attn_type == 0: the default (relative, partially learnable) attention
        for _ in range(config.n_layer):
            self.layers.append(
                RelPartialLearnableDecoderLayer(
                    config.n_head,
                    config.d_model,
                    config.d_head,
                    config.d_inner,
                    config.dropout,
                    tgt_len=config.tgt_len,
                    ext_len=config.ext_len,
                    mem_len=config.mem_len,
                    dropatt=config.dropatt,
                    pre_lnorm=config.pre_lnorm,
                    r_w_bias=None if config.untie_r else self.r_w_bias,
                    r_r_bias=None if config.untie_r else self.r_r_bias,
                    output_attentions=self.output_attentions,
                    layer_norm_epsilon=config.layer_norm_epsilon,
                )
            )

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        self.pos_emb = PositionalEmbedding(self.d_model)

        self.init_weights()

    def get_input_embeddings(self):
        return self.word_emb

    def set_input_embeddings(self, new_embeddings):
        self.word_emb = new_embeddings

    def backward_compatible(self):
        self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        # Update the segment, extended-context and memory lengths used from now on.
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def _prune_heads(self, heads):
        logger.info("Head pruning is not implemented for Transformer-XL model")

    def init_mems(self, bsz):
        # Returns one zero tensor of shape (mem_len, bsz, d_model) per layer,
        # or None when the model keeps no memory.
        if self.mem_len > 0:
            param = next(self.parameters())
            return [
                torch.zeros(self.mem_len, bsz, self.config.d_model, dtype=param.dtype, device=param.device)
                for _ in range(self.n_layer)
            ]
        return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # Builds the memory for the next call from the current memory `mems`
        # (length `mlen`) and the hidden states `hids` of the current segment
        # (length `qlen`). Returns None when there is no memory.
        if mems is None:
            return None

        # mems is not None
        assert len(hids) == len(mems), "len(hids) != len(mems)"

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            new_mems = []
            for mem, hid in zip(mems, hids):
                cat = torch.cat([mem, hid], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    @add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import TransfoXLTokenizer, TransfoXLModel
        import torch

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states, mems = outputs[:2]

        """
        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            qlen, bsz = input_ids.size()
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if mems is None:
            mems = self.init_mems(bsz)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        if inputs_embeds is not None:
            word_emb = inputs_embeds
        else:
            word_emb = self.word_emb(input_ids)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen

        # Attention mask: 1 marks a (query, key) pair that must not be attended to.
        if self.same_length:
            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
            mask_len = klen - self.mem_len
            mask_shift_len = qlen - mask_len if mask_len > 0 else qlen
            # Future positions, plus the oldest positions.
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None]
        else:
            # Only future positions are masked.
            dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[
                :, :, None
            ]

        if self.attn_type != 0:
            # Learnable embeddings and absolute embeddings were removed to avoid maintaining
            # dead code - they are not used in the pretrained checkpoint.
            raise NotImplementedError

        hids = []
        attentions = []

        # Relative positions, from the farthest key (klen - 1) down to 0.
        pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)

        core_out = self.drop(word_emb)
        pos_emb = self.drop(pos_emb)

        for i, layer in enumerate(self.layers):
            # The *input* of every layer is recorded; it is what gets cached into mems.
            hids.append(core_out)
            mems_i = None if mems is None else mems[i]
            layer_outputs = layer(
                core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i, head_mask=head_mask[i]
            )
            core_out = layer_outputs[0]
            if self.output_attentions:
                attentions.append(layer_outputs[1])

        core_out = self.drop(core_out)

        # Keyword arguments keep the two lengths from being swapped
        # (`_update_mems` takes `qlen` before `mlen`).
        new_mems = self._update_mems(hids, mems, qlen=qlen, mlen=mlen)

        # We transpose back here to shape [bsz, len, hidden_dim]
        outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
        if self.output_hidden_states:
            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
            hids.append(core_out)
            hids = [t.transpose(0, 1).contiguous() for t in hids]
            outputs.append(hids)
        if self.output_attentions:
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = [t.permute(2, 3, 0, 1).contiguous() for t in attentions]
            outputs.append(attentions)

        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
|
|
|
|
|
|
@add_start_docstrings(
    """The Transformer-XL Model with a language modeling head on top
    (adaptive softmax with weights tied to the adaptive input embeddings)""",
    TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    def __init__(self, config):
        """Build the base ``TransfoXLModel`` and the output head selected by ``config.sample_softmax``.

        When ``config.sample_softmax > 0`` the head is an ``nn.Linear`` output layer plus a
        ``LogUniformSampler``; otherwise it is a ``ProjectedAdaptiveLogSoftmax`` criterion
        stored as ``self.crit``.
        """
        super().__init__(config)
        self.transformer = TransfoXLModel(config)
        self.sample_softmax = config.sample_softmax

        if config.sample_softmax > 0:
            # use sampled softmax
            self.out_layer = nn.Linear(config.d_model, config.vocab_size)
            self.sampler = LogUniformSampler(config.vocab_size, config.sample_softmax)
        else:
            # use adaptive softmax (including standard softmax)
            self.crit = ProjectedAdaptiveLogSoftmax(
                config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
            )

        self.init_weights()

    def tie_weights(self):
        """
        Run this to be sure output and input (adaptive) softmax weights are tied
        """
        # sampled softmax: a single output layer shares the word embedding weight
        if self.sample_softmax > 0:
            if self.config.tie_weight:
                self.out_layer.weight = self.transformer.word_emb.weight
            return

        # adaptive softmax (including standard softmax)
        if self.config.tie_weight:
            for i, out_layer in enumerate(self.crit.out_layers):
                self._tie_or_clone_weights(out_layer, self.transformer.word_emb.emb_layers[i])

        if not self.config.tie_projs:
            return

        for i, tie_proj in enumerate(self.config.tie_projs):
            if not tie_proj:
                continue

            if self.config.div_val == 1:
                # With div_val == 1 only the first embedding projection is used as the
                # source, and only when the model and embedding widths differ.
                if self.config.d_model == self.config.d_embed:
                    continue
                emb_proj = self.transformer.word_emb.emb_projs[0]
            else:
                emb_proj = self.transformer.word_emb.emb_projs[i]

            # Under torchscript the projection is cloned into a fresh parameter
            # instead of being shared with the embedding.
            if self.config.torchscript:
                self.crit.out_projs[i] = nn.Parameter(emb_proj.clone())
            else:
                self.crit.out_projs[i] = emb_proj

    def reset_length(self, tgt_len, ext_len, mem_len):
        """Forward ``tgt_len``, ``ext_len`` and ``mem_len`` to ``self.transformer.reset_length``."""
        self.transformer.reset_length(tgt_len, ext_len, mem_len)

    def init_mems(self, bsz):
        """Return the result of ``self.transformer.init_mems(bsz)``."""
        return self.transformer.init_mems(bsz)

    @add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
    def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``

    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, returned when ``labels`` is provided)
            Language modeling loss (per token, not reduced).
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            ``None`` when ``labels`` is provided.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import TransfoXLTokenizer, TransfoXLLMHeadModel
        import torch

        tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        prediction_scores, mems = outputs[:2]

        """
        # Batch size and target length come from whichever input was supplied.
        if input_ids is not None:
            bsz, tgt_len = input_ids.size(0), input_ids.size(1)
        elif inputs_embeds is not None:
            bsz, tgt_len = inputs_embeds.size(0), inputs_embeds.size(1)
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        transformer_outputs = self.transformer(input_ids, mems=mems, head_mask=head_mask, inputs_embeds=inputs_embeds)

        last_hidden = transformer_outputs[0]
        # Keep only the trailing tgt_len positions of the last hidden state.
        pred_hid = last_hidden[:, -tgt_len:]
        outputs = transformer_outputs[1:]

        if self.sample_softmax > 0 and self.training:
            assert self.config.tie_weight
            if labels is not None:
                # TODO: This is not implemented
                # Raised before any computation so the unsupported path fails fast.
                raise NotImplementedError("Sampled softmax with `labels` is not implemented.")
            logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
            softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
            outputs = [softmax_output] + outputs
        else:
            # NOTE(review): `self.crit` is only created in __init__ when
            # config.sample_softmax <= 0, so reaching this branch with
            # sample_softmax > 0 (i.e. outside training mode) fails on the
            # missing attribute -- confirm the intended evaluation behavior.
            softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
            if labels is None:
                softmax_output = softmax_output.view(bsz, tgt_len, -1)
                outputs = [softmax_output] + outputs
            else:
                softmax_output = softmax_output.view(bsz, tgt_len)
                outputs = [softmax_output, None] + outputs

        return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)

    def get_output_embeddings(self):
        """ Double-check if you are using adaptive softmax.

        Returns ``self.out_layer`` when sampled softmax is configured, otherwise the
        last entry of ``self.crit.out_layers``.
        """
        if self.sample_softmax > 0:
            return self.out_layer
        return self.crit.out_layers[-1]

    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
        """Build the keyword arguments passed to ``forward`` during generation.

        Always contains ``input_ids``; ``mems`` is added when a truthy ``past`` is present
        in ``model_kwargs``.
        """
        inputs = {"input_ids": input_ids}

        # if past is defined in model kwargs then use it for faster decoding
        past = model_kwargs.get("past")
        if past:
            inputs["mems"] = past

        return inputs
|