From 3a9c88377fed787d97f98e1a59ed33ab413e3705 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 15 Jan 2019 12:59:38 +0100 Subject: [PATCH] adding Transformer XL --- ...onvert_transfo_xl_checkpoint_to_pytorch.py | 125 ++ pytorch_pretrained_bert/modeling_openai.py | 17 + .../modeling_transfo_xl.py | 1432 +++++++++++++++++ .../modeling_transfo_xl_utilities.py | 314 ++++ .../tokenization_transfo_xl.py | 508 ++++++ 5 files changed, 2396 insertions(+) create mode 100755 pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py create mode 100644 pytorch_pretrained_bert/modeling_transfo_xl.py create mode 100644 pytorch_pretrained_bert/modeling_transfo_xl_utilities.py create mode 100644 pytorch_pretrained_bert/tokenization_transfo_xl.py diff --git a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py new file mode 100755 index 0000000000..03f71defd6 --- /dev/null +++ b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py @@ -0,0 +1,125 @@ +# coding=utf-8 +# Copyright 2018 The HugginFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert OpenAI GPT checkpoint.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re +import argparse +import tensorflow as tf +import torch +import numpy as np + +from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME + + +def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, + transfo_xl_config_file, + pytorch_dump_folder_path): + config_path = os.path.abspath(transfo_xl_config_file) + tf_path = os.path.abspath(tf_checkpoint_path) + + print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + # Initialise PyTorch model + # Construct model + if transfo_xl_config_file == "": + config = TransfoXLConfig() + else: + config = TransfoXLConfig(transfo_xl_config_file) + print("Building PyTorch model from configuration: {}".format(str(config))) + model = TransfoXLModel(config) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["adam_v", "adam_m"] for n in name): + print("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) + else: + l = [m_name] + if l[0] == 'kernel' or l[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif l[0] == 'output_bias' or l[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif l[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + else: + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + + # Save pytorch-model + pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME + pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME + print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) + torch.save(model.state_dict(), pytorch_weights_dump_path) + print("Save configuration file to {}".format(pytorch_config_dump_path)) + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + ## Required parameters + parser.add_argument("--tf_checkpoint_path", + default = None, + type = str, + required = True, + help = "Path the TensorFlow checkpoint path.") + parser.add_argument("--transfo_xl_config_file", + default = None, + type = str, + required = True, + help = "The config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture.") + parser.add_argument("--pytorch_dump_folder_path", + default = None, + type = str, + required = True, + help = "Path to the output PyTorch model.") + args = parser.parse_args() + convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, + args.transfo_xl_config_file, + args.pytorch_dump_folder_path) diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index 40557c9626..c3cd165e68 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -1,3 +1,20 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT model.""" + import os import copy import json diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py new file mode 100644 index 0000000000..fccd4616e4 --- /dev/null +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -0,0 +1,1432 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Transformer XL model. + Directly adapted from https://github.com/kimiyoung/transformer-xl. + In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py +""" + +import os +import copy +import json +import math +import logging +import tarfile +import tempfile +import shutil +import collections + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss +from torch.nn.parameter import Parameter + +from .modeling import BertLayerNorm as LayerNorm +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl.tar.gz", +} +CONFIG_NAME = 'transfo_xl_config.json' +WEIGHTS_NAME = 'pytorch_model.bin' + +class TransfoXLConfig(object): + """Configuration class to store the configuration of a `TransfoXLModel`. + """ + def __init__(self, + vocab_size_or_config_json_file=267735, + cutoffs=[20000, 40000, 200000], + d_model=410, + d_embed=410, + d_head=41, + d_inner=2100, + div_val=1.0, + pre_lnorm=False, + n_layer=16, + n_head=10, + tgt_len=150, + ext_len=0, + mem_len=150, + same_length=False, + attn_type=0, + clamp_len=-1, + sample_softmax=-1, + adaptive=True, + tied=True, + dropout=0.1, + dropatt=0.0, + init="normal", + init_range=0.01, + proj_init_std=0.01, + init_std=0.02): + """Constructs TransfoXLConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file. + cutoffs: cutoffs for the adaptive softmax + d_model: Dimensionality of the model's hidden states. + d_embed: Dimensionality of the embeddings + d_head: Dimensionality of the model's heads. + div_val: divident value for adapative input and softmax + pre_lnorm: apply LayerNorm to the input instead of the output + d_inner: Inner dimension in FF + n_layer: Number of hidden layers in the Transformer encoder. + n_head: Number of attention heads for each attention layer in + the Transformer encoder. + tgt_len: number of tokens to predict + ext_len: length of the extended context + mem_len: length of the retained previous heads + same_length: use the same attn length for all tokens + attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. + clamp_len: use the same pos embeddings after clamp_len + sample_softmax: number of samples in sampled softmax + adaptive: use adaptive softmax + tied: tie the word embedding and softmax weights + dropout: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + dropatt: The dropout ratio for the attention probabilities. + embd_pdrop: The dropout ratio for the embeddings. + init: parameter initializer to use + init_range: parameters initialized by U(-init_range, init_range). + proj_init_std: parameters initialized by N(0, init_std) + init_std: parameters initialized by N(0, init_std) + """ + if isinstance(vocab_size_or_config_json_file, str): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.cutoffs = [] + self.cutoffs.extend(cutoffs) + self.tie_projs = [False] + [True] * len(self.cutoffs) + self.d_model = d_model + self.d_embed = d_embed + self.d_head = d_head + self.d_inner = d_inner + self.div_val = div_val + self.pre_lnorm = pre_lnorm + self.n_layer = n_layer + self.n_head = n_head + self.tgt_len = tgt_len + self.ext_len = ext_len + self.mem_len = mem_len + self.same_length = same_length + self.attn_type = attn_type + self.clamp_len = clamp_len + self.sample_softmax = sample_softmax + self.adaptive = adaptive + self.tied = tied + self.dropout = dropout + self.dropatt = dropatt + self.init = init + self.init_range = init_range + self.proj_init_std = proj_init_std + self.init_std = init_std + else: + raise ValueError("First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)") + + @property + def total_num_embeddings(self): + return self.vocab_size + self.n_special + self.n_ctx + + @classmethod + def from_dict(cls, json_object): + """Constructs a `TransfoXLConfig` from a Python dictionary of parameters.""" + config = TransfoXLConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `TransfoXLConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + +class PositionalEmbedding(nn.Module): + def __init__(self, demb): + super(PositionalEmbedding, self).__init__() + + self.demb = demb + + inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, pos_seq, bsz=None): + sinusoid_inp = torch.ger(pos_seq, self.inv_freq) + pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) + + if bsz is not None: + return pos_emb[:,None,:].expand(-1, bsz, -1) + else: + return pos_emb[:,None,:] + + +class PositionwiseFF(nn.Module): + def __init__(self, d_model, d_inner, dropout, pre_lnorm=False): + super(PositionwiseFF, self).__init__() + + self.d_model = d_model + self.d_inner = d_inner + self.dropout = dropout + + self.CoreNet = nn.Sequential( + nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout), + ) + + self.layer_norm = nn.LayerNorm(d_model) + + self.pre_lnorm = pre_lnorm + + def forward(self, inp): + if self.pre_lnorm: + ##### layer normalization + positionwise feed-forward + core_out = self.CoreNet(self.layer_norm(inp)) + + ##### residual connection + output = core_out + inp + else: + ##### positionwise feed-forward + core_out = self.CoreNet(inp) + + ##### residual connection + layer normalization + output = self.layer_norm(inp + core_out) + + return output + +class MultiHeadAttn(nn.Module): + def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, + pre_lnorm=False): + super(MultiHeadAttn, self).__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_head = d_head + self.dropout = dropout + + self.q_net = nn.Linear(d_model, n_head * d_head, bias=False) + self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False) + + self.drop = nn.Dropout(dropout) + self.dropatt = nn.Dropout(dropatt) + self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) + + self.layer_norm = nn.LayerNorm(d_model) + + self.scale = 1 / (d_head ** 0.5) + + self.pre_lnorm = pre_lnorm + + def forward(self, h, attn_mask=None, mems=None): + ##### multihead attention + # [hlen x bsz x n_head x d_head] + + if mems is not None: + c = torch.cat([mems, h], 0) + else: + c = h + + if self.pre_lnorm: + ##### layer normalization + c = self.layer_norm(c) + + head_q = self.q_net(h) + head_k, head_v = torch.chunk(self.kv_net(c), 2, -1) + + head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head) + head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head) + head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head) + + # [qlen x klen x bsz x n_head] + attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k)) + attn_score.mul_(self.scale) + if attn_mask is not None and attn_mask.any().item(): + if attn_mask.dim() == 2: + attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) + elif attn_mask.dim() == 3: + attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) + + # [qlen x klen x bsz x n_head] + attn_prob = F.softmax(attn_score, dim=1) + attn_prob = self.dropatt(attn_prob) + + # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head] + attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v)) + attn_vec = attn_vec.contiguous().view( + attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) + + ##### linear projection + attn_out = self.o_net(attn_vec) + attn_out = self.drop(attn_out) + + if self.pre_lnorm: + ##### residual connection + output = h + attn_out + else: + ##### residual connection + layer normalization + output = self.layer_norm(h + attn_out) + + return output + +class RelMultiHeadAttn(nn.Module): + def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, + tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False): + super(RelMultiHeadAttn, self).__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_head = d_head + self.dropout = dropout + + self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) + + self.drop = nn.Dropout(dropout) + self.dropatt = nn.Dropout(dropatt) + self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) + + self.layer_norm = nn.LayerNorm(d_model) + + self.scale = 1 / (d_head ** 0.5) + + self.pre_lnorm = pre_lnorm + + def _parallelogram_mask(self, h, w, left=False): + mask = torch.ones((h, w)).byte() + m = min(h, w) + mask[:m,:m] = torch.triu(mask[:m,:m]) + mask[-m:,-m:] = torch.tril(mask[-m:,-m:]) + + if left: + return mask + else: + return mask.flip(0) + + def _shift(self, x, qlen, klen, mask, left=False): + if qlen > 1: + zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)), + device=x.device, dtype=x.dtype) + else: + zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) + + if left: + mask = mask.flip(1) + x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1) + else: + x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1) + + x = x_padded.masked_select(mask[:,:,None,None]) \ + .view(qlen, klen, x.size(2), x.size(3)) + + return x + + def _rel_shift(self, x, zero_triu=False): + zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), + device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=1) + + x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) + + x = x_padded[1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(0), x.size(1))) + x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None] + + return x + + def forward(self, w, r, attn_mask=None, mems=None): + raise NotImplementedError + +class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): + def __init__(self, *args, **kwargs): + super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs) + + self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) + + def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): + qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) + + if mems is not None: + cat = torch.cat([mems, w], 0) + if self.pre_lnorm: + w_heads = self.qkv_net(self.layer_norm(cat)) + else: + w_heads = self.qkv_net(cat) + r_head_k = self.r_net(r) + + w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) + w_head_q = w_head_q[-qlen:] + else: + if self.pre_lnorm: + w_heads = self.qkv_net(self.layer_norm(w)) + else: + w_heads = self.qkv_net(w) + r_head_k = self.r_net(r) + + w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) + + klen = w_head_k.size(0) + + w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head + w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head + w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head + + r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head + + #### compute attention score + rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head + AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head + + rr_head_q = w_head_q + r_r_bias + BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head + BD = self._rel_shift(BD) + + # [qlen x klen x bsz x n_head] + attn_score = AC + BD + attn_score.mul_(self.scale) + + #### compute attention probability + if attn_mask is not None and attn_mask.any().item(): + if attn_mask.dim() == 2: + attn_score = attn_score.float().masked_fill( + attn_mask[None,:,:,None], -float('inf')).type_as(attn_score) + elif attn_mask.dim() == 3: + attn_score = attn_score.float().masked_fill( + attn_mask[:,:,:,None], -float('inf')).type_as(attn_score) + + # [qlen x klen x bsz x n_head] + attn_prob = F.softmax(attn_score, dim=1) + attn_prob = self.dropatt(attn_prob) + + #### compute attention vector + attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) + + # [qlen x bsz x n_head x d_head] + attn_vec = attn_vec.contiguous().view( + attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) + + ##### linear projection + attn_out = self.o_net(attn_vec) + attn_out = self.drop(attn_out) + + if self.pre_lnorm: + ##### residual connection + output = w + attn_out + else: + ##### residual connection + layer normalization + output = self.layer_norm(w + attn_out) + + return output + +class RelLearnableMultiHeadAttn(RelMultiHeadAttn): + def __init__(self, *args, **kwargs): + super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs) + + def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None): + # r_emb: [klen, n_head, d_head], used for term B + # r_w_bias: [n_head, d_head], used for term C + # r_bias: [klen, n_head], used for term D + + qlen, bsz = w.size(0), w.size(1) + + if mems is not None: + cat = torch.cat([mems, w], 0) + if self.pre_lnorm: + w_heads = self.qkv_net(self.layer_norm(cat)) + else: + w_heads = self.qkv_net(cat) + w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) + + w_head_q = w_head_q[-qlen:] + else: + if self.pre_lnorm: + w_heads = self.qkv_net(self.layer_norm(w)) + else: + w_heads = self.qkv_net(w) + w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) + + klen = w_head_k.size(0) + + w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) + w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) + w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) + + if klen > r_emb.size(0): + r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1) + r_emb = torch.cat([r_emb_pad, r_emb], 0) + r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1) + r_bias = torch.cat([r_bias_pad, r_bias], 0) + else: + r_emb = r_emb[-klen:] + r_bias = r_bias[-klen:] + + #### compute attention score + rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head + + AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head + B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head + D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head + BD = self._rel_shift(B_ + D_) + + # [qlen x klen x bsz x n_head] + attn_score = AC + BD + attn_score.mul_(self.scale) + + #### compute attention probability + if attn_mask is not None and attn_mask.any().item(): + if attn_mask.dim() == 2: + attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) + elif attn_mask.dim() == 3: + attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) + + # [qlen x klen x bsz x n_head] + attn_prob = F.softmax(attn_score, dim=1) + attn_prob = self.dropatt(attn_prob) + + #### compute attention vector + attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) + + # [qlen x bsz x n_head x d_head] + attn_vec = attn_vec.contiguous().view( + attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) + + ##### linear projection + attn_out = self.o_net(attn_vec) + attn_out = self.drop(attn_out) + + if self.pre_lnorm: + ##### residual connection + output = w + attn_out + else: + ##### residual connection + layer normalization + output = self.layer_norm(w + attn_out) + + return output + +class DecoderLayer(nn.Module): + def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): + super(DecoderLayer, self).__init__() + + self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) + self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, + pre_lnorm=kwargs.get('pre_lnorm')) + + def forward(self, dec_inp, dec_attn_mask=None, mems=None): + + output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, + mems=mems) + output = self.pos_ff(output) + + return output + +class RelLearnableDecoderLayer(nn.Module): + def __init__(self, n_head, d_model, d_head, d_inner, dropout, + **kwargs): + super(RelLearnableDecoderLayer, self).__init__() + + self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, + **kwargs) + self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, + pre_lnorm=kwargs.get('pre_lnorm')) + + def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None): + + output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias, + attn_mask=dec_attn_mask, + mems=mems) + output = self.pos_ff(output) + + return output + +class RelPartialLearnableDecoderLayer(nn.Module): + def __init__(self, n_head, d_model, d_head, d_inner, dropout, + **kwargs): + super(RelPartialLearnableDecoderLayer, self).__init__() + + self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, + d_head, dropout, **kwargs) + self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, + pre_lnorm=kwargs.get('pre_lnorm')) + + def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): + + output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias, + attn_mask=dec_attn_mask, + mems=mems) + output = self.pos_ff(output) + + return output + + +class AdaptiveEmbedding(nn.Module): + def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, + sample_softmax=False): + super(AdaptiveEmbedding, self).__init__() + + self.n_token = n_token + self.d_embed = d_embed + + self.cutoffs = cutoffs + [n_token] + self.div_val = div_val + self.d_proj = d_proj + + self.emb_scale = d_proj ** 0.5 + + self.cutoff_ends = [0] + self.cutoffs + + self.emb_layers = nn.ModuleList() + self.emb_projs = nn.ParameterList() + if div_val == 1: + self.emb_layers.append( + nn.Embedding(n_token, d_embed, sparse=sample_softmax>0) + ) + if d_proj != d_embed: + self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed))) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] + d_emb_i = d_embed // (div_val ** i) + self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i)) + self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i))) + + def forward(self, inp): + if self.div_val == 1: + embed = self.emb_layers[0](inp) + if self.d_proj != self.d_embed: + embed = F.linear(embed, self.emb_projs[0]) + else: + param = next(self.parameters()) + inp_flat = inp.view(-1) + emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], + dtype=param.dtype, device=param.device) + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + + mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) + indices_i = mask_i.nonzero().squeeze() + + if indices_i.numel() == 0: + continue + + inp_i = inp_flat.index_select(0, indices_i) - l_idx + emb_i = self.emb_layers[i](inp_i) + emb_i = F.linear(emb_i, self.emb_projs[i]) + + emb_flat.index_copy_(0, indices_i, emb_i) + + embed = emb_flat.view(*inp.size(), self.d_proj) + + embed.mul_(self.emb_scale) + + return embed + +class MemTransformerLM(nn.Module): + def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner, + dropout, dropatt, tie_weight=True, d_embed=None, + div_val=1, tie_projs=[False], pre_lnorm=False, + tgt_len=None, ext_len=None, mem_len=None, + cutoffs=[], adapt_inp=False, + same_length=False, attn_type=0, clamp_len=-1, + sample_softmax=-1): + super(MemTransformerLM, self).__init__() + self.n_token = n_token + + d_embed = d_model if d_embed is None else d_embed + self.d_embed = d_embed + self.d_model = d_model + self.n_head = n_head + self.d_head = d_head + + self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, + div_val=div_val) + + self.drop = nn.Dropout(dropout) + + self.n_layer = n_layer + + self.tgt_len = tgt_len + self.mem_len = mem_len + self.ext_len = ext_len + self.max_klen = tgt_len + ext_len + mem_len + + self.attn_type = attn_type + + self.layers = nn.ModuleList() + if attn_type == 0: # the default attention + for i in range(n_layer): + self.layers.append( + RelPartialLearnableDecoderLayer( + n_head, d_model, d_head, d_inner, dropout, + tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, + dropatt=dropatt, pre_lnorm=pre_lnorm) + ) + elif attn_type == 1: # learnable embeddings + for i in range(n_layer): + self.layers.append( + RelLearnableDecoderLayer( + n_head, d_model, d_head, d_inner, dropout, + tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, + dropatt=dropatt, pre_lnorm=pre_lnorm) + ) + elif attn_type in [2, 3]: # absolute embeddings + for i in range(n_layer): + self.layers.append( + DecoderLayer( + n_head, d_model, d_head, d_inner, dropout, + dropatt=dropatt, pre_lnorm=pre_lnorm) + ) + + self.sample_softmax = sample_softmax + # use sampled softmax + if sample_softmax > 0: + self.out_layer = nn.Linear(d_model, n_token) + if tie_weight: + self.out_layer.weight = self.word_emb.weight + self.tie_weight = tie_weight + self.sampler = LogUniformSampler(n_token, sample_softmax) + + # use adaptive softmax (including standard softmax) + else: + self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, + cutoffs, div_val=div_val) + + if tie_weight: + for i in range(len(self.crit.out_layers)): + self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight + + if tie_projs: + for i, tie_proj in enumerate(tie_projs): + if tie_proj and div_val == 1 and d_model != d_embed: + self.crit.out_projs[i] = self.word_emb.emb_projs[0] + elif tie_proj and div_val != 1: + self.crit.out_projs[i] = self.word_emb.emb_projs[i] + + self.same_length = same_length + self.clamp_len = clamp_len + + self._create_params() + + def backward_compatible(self): + self.sample_softmax = -1 + + def _create_params(self): + if self.attn_type == 0: # default attention + self.pos_emb = PositionalEmbedding(self.d_model) + self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) + self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) + elif self.attn_type == 1: # learnable + self.r_emb = nn.Parameter(torch.Tensor( + self.n_layer, self.max_klen, self.n_head, self.d_head)) + self.r_w_bias = nn.Parameter(torch.Tensor( + self.n_layer, self.n_head, self.d_head)) + self.r_bias = nn.Parameter(torch.Tensor( + self.n_layer, self.max_klen, self.n_head)) + elif self.attn_type == 2: # absolute standard + self.pos_emb = PositionalEmbedding(self.d_model) + elif self.attn_type == 3: # absolute deeper SA + self.r_emb = nn.Parameter(torch.Tensor( + self.n_layer, self.max_klen, self.n_head, self.d_head)) + + def reset_length(self, tgt_len, ext_len, mem_len): + self.tgt_len = tgt_len + self.mem_len = mem_len + self.ext_len = ext_len + + def init_mems(self): + if self.mem_len > 0: + mems = [] + param = next(self.parameters()) + for i in range(self.n_layer+1): + empty = torch.empty(0, dtype=param.dtype, device=param.device) + mems.append(empty) + + return mems + else: + return None + + def _update_mems(self, hids, mems, qlen, mlen): + # does not deal with None + if mems is None: return None + + # mems is not None + assert len(hids) == len(mems), 'len(hids) != len(mems)' + + # There are `mlen + qlen` steps that can be cached into mems + # For the next step, the last `ext_len` of the `qlen` tokens + # will be used as the extended context. Hence, we only cache + # the tokens from `mlen + qlen - self.ext_len - self.mem_len` + # to `mlen + qlen - self.ext_len`. + with torch.no_grad(): + new_mems = [] + end_idx = mlen + max(0, qlen - 0 - self.ext_len) + beg_idx = max(0, end_idx - self.mem_len) + for i in range(len(hids)): + + cat = torch.cat([mems[i], hids[i]], dim=0) + new_mems.append(cat[beg_idx:end_idx].detach()) + + return new_mems + + def _forward(self, dec_inp, mems=None): + qlen, bsz = dec_inp.size() + + word_emb = self.word_emb(dec_inp) + + mlen = mems[0].size(0) if mems is not None else 0 + klen = mlen + qlen + if self.same_length: + all_ones = word_emb.new_ones(qlen, klen) + mask_len = klen - self.mem_len + if mask_len > 0: + mask_shift_len = qlen - mask_len + else: + mask_shift_len = qlen + dec_attn_mask = (torch.triu(all_ones, 1+mlen) + + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1 + else: + dec_attn_mask = torch.triu( + word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] + + hids = [] + if self.attn_type == 0: # default + pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, + dtype=word_emb.dtype) + if self.clamp_len > 0: + pos_seq.clamp_(max=self.clamp_len) + pos_emb = self.pos_emb(pos_seq) + + core_out = self.drop(word_emb) + pos_emb = self.drop(pos_emb) + + hids.append(core_out) + for i, layer in enumerate(self.layers): + mems_i = None if mems is None else mems[i] + core_out = layer(core_out, pos_emb, self.r_w_bias, + self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) + hids.append(core_out) + elif self.attn_type == 1: # learnable + core_out = self.drop(word_emb) + hids.append(core_out) + for i, layer in enumerate(self.layers): + if self.clamp_len > 0: + r_emb = self.r_emb[i][-self.clamp_len :] + r_bias = self.r_bias[i][-self.clamp_len :] + else: + r_emb, r_bias = self.r_emb[i], self.r_bias[i] + + mems_i = None if mems is None else mems[i] + core_out = layer(core_out, r_emb, self.r_w_bias[i], + r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) + hids.append(core_out) + elif self.attn_type == 2: # absolute + pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, + dtype=word_emb.dtype) + if self.clamp_len > 0: + pos_seq.clamp_(max=self.clamp_len) + pos_emb = self.pos_emb(pos_seq) + + core_out = self.drop(word_emb + pos_emb[-qlen:]) + + hids.append(core_out) + for i, layer in enumerate(self.layers): + mems_i = None if mems is None else mems[i] + if mems_i is not None and i == 0: + mems_i += pos_emb[:mlen] + core_out = layer(core_out, dec_attn_mask=dec_attn_mask, + mems=mems_i) + hids.append(core_out) + elif self.attn_type == 3: + core_out = self.drop(word_emb) + + hids.append(core_out) + for i, layer in enumerate(self.layers): + mems_i = None if mems is None else mems[i] + if mems_i is not None and mlen > 0: + cur_emb = self.r_emb[i][:-qlen] + cur_size = cur_emb.size(0) + if cur_size < mlen: + cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1) + cur_emb = torch.cat([cur_emb_pad, cur_emb], 0) + else: + cur_emb = cur_emb[-mlen:] + mems_i += cur_emb.view(mlen, 1, -1) + core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) + + core_out = layer(core_out, dec_attn_mask=dec_attn_mask, + mems=mems_i) + hids.append(core_out) + + core_out = self.drop(core_out) + + new_mems = self._update_mems(hids, mems, mlen, qlen) + + return core_out, new_mems + + def forward(self, data, target, *mems): + # nn.DataParallel does not allow size(0) tensors to be broadcasted. + # So, have to initialize size(0) mems inside the model forward. + # Moreover, have to return new_mems to allow nn.DataParallel to piece + # them together. + if not mems: mems = self.init_mems() + + tgt_len = target.size(0) + hidden, new_mems = self._forward(data, mems=mems) + + pred_hid = hidden[-tgt_len:] + if self.sample_softmax > 0 and self.training: + assert self.tie_weight + logit = sample_logits(self.word_emb, + self.out_layer.bias, target, pred_hid, self.sampler) + loss = -F.log_softmax(logit, -1)[:, :, 0] + else: + loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1)) + loss = loss.view(tgt_len, -1) + + if new_mems is None: + return [loss] + else: + return [loss] + new_mems + + +class TransfoXLPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super(TransfoXLPreTrainedModel, self).__init__() + if not isinstance(config, TransfoXLConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. " + "To create a model from a pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config + + def init_weight(weight): + if self.config.init == 'uniform': + nn.init.uniform_(weight, -self.config.init_range, self.config.init_range) + elif self.config.init == 'normal': + nn.init.normal_(weight, 0.0, self.config.init_std) + + def init_bias(bias): + nn.init.constant_(bias, 0.0) + + def init_weights(self, m): + """ Initialize the weights. + """ + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + if hasattr(m, 'weight') and m.weight is not None: + self.init_weight(m.weight) + if hasattr(m, 'bias') and m.bias is not None: + self.init_bias(m.bias) + elif classname.find('AdaptiveEmbedding') != -1: + if hasattr(m, 'emb_projs'): + for i in range(len(m.emb_projs)): + if m.emb_projs[i] is not None: + nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std) + elif classname.find('Embedding') != -1: + if hasattr(m, 'weight'): + self.init_weight(m.weight) + elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: + if hasattr(m, 'cluster_weight') and m.cluster_weight is not None: + self.init_weight(m.cluster_weight) + if hasattr(m, 'cluster_bias') and m.cluster_bias is not None: + self.init_bias(m.cluster_bias) + if hasattr(m, 'out_projs'): + for i in range(len(m.out_projs)): + if m.out_projs[i] is not None: + nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std) + elif classname.find('LayerNorm') != -1: + if hasattr(m, 'weight'): + nn.init.normal_(m.weight, 1.0, self.config.init_std) + if hasattr(m, 'bias') and m.bias is not None: + self.init_bias(m.bias) + elif classname.find('TransformerLM') != -1: + if hasattr(m, 'r_emb'): + self.init_weight(m.r_emb) + if hasattr(m, 'r_w_bias'): + self.init_weight(m.r_w_bias) + if hasattr(m, 'r_r_bias'): + self.init_weight(m.r_r_bias) + if hasattr(m, 'r_bias'): + self.init_bias(m.r_bias) + + def set_num_special_tokens(self, num_special_tokens): + pass + + @classmethod + def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None, + *inputs, **kwargs): + """ + Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `transfo-xl` + - a path or url to a pretrained model archive containing: + . `transfo_xl_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] + else: + archive_file = pretrained_model_name + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + except FileNotFoundError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name, + ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), + archive_file)) + return None + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info("loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file)) + tempdir = None + if os.path.isdir(resolved_archive_file): + serialization_dir = resolved_archive_file + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info("extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir)) + with tarfile.open(resolved_archive_file, 'r:gz') as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + config_file = os.path.join(serialization_dir, CONFIG_NAME) + config = TransfoXLConfig.from_json_file(config_file) + logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + if state_dict is None: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path) + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + load(model.transformer if hasattr(model, 'transformer') else model, prefix='') + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + # Add additional embeddings for special tokens if needed + if num_special_tokens != config.n_special: + model.set_num_special_tokens(num_special_tokens) + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + return model + + + + + + +################### + + + + +class TransfoXLLMHead(nn.Module): + """ Language Model Head for the transformer """ + + def __init__(self, model_embeddings_weights, config): + super(TransfoXLLMHead, self).__init__() + self.n_embd = config.n_embd + self.set_embeddings_weights(model_embeddings_weights) + + def set_embeddings_weights(self, model_embeddings_weights): + embed_shape = model_embeddings_weights.shape + self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) + self.decoder.weight = model_embeddings_weights # Tied weights + + def forward(self, hidden_state): + # Truncated Language modeling logits (we remove the last token) + # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) + lm_logits = self.decoder(hidden_state) + return lm_logits + + +class TransfoXLMultipleChoiceHead(nn.Module): + """ Classifier Head for the transformer """ + + def __init__(self, config): + super(TransfoXLMultipleChoiceHead, self).__init__() + self.n_embd = config.n_embd + # self.multiple_choice_token = multiple_choice_token + self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation + self.linear = nn.Linear(config.n_embd, 1) + + nn.init.normal_(self.linear.weight, std = 0.02) + nn.init.normal_(self.linear.bias, 0) + + def forward(self, hidden_states, multiple_choice_token_mask): + # Classification logits + # hidden_states = hidden_states.view(-1, self.n_embd) + # multiple_choice_token_mask = multiple_choice_token_mask.view(-1, 1).expand_as(hidden_states) + multiple_choice_h = hidden_states * multiple_choice_token_mask.unsqueeze(-1) + multiple_choice_h = multiple_choice_h.sum(dim=-2) + # flat = x[..., 0].contiguous().view(-1) + # multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :] + # multiple_choice_h = multiple_choice_h.view(-1, x.size(1), self.n_embd, 1) + # # This double transposition is there to replicate the behavior + # # of the noise_shape argument in the tensorflow + # # implementation. For more details, see + # # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11 + # multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2) + # multiple_choice_h = multiple_choice_h.contiguous().view(-1, self.n_embd) + multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1) + return multiple_choice_logits + + +class TransfoXLModel(TransfoXLPreTrainedModel): + """OpenAI GPT model ("Improving Language Understanding by Generative Pre-Training"). + + The main implementation difference between BERT and the OpenAI is the use, in OpenAI GPT, of a single embedding matrix + to store the word, special ([SEP], [CLS]...) and position embeddings. + The embeddings are ordered as follow in the word embeddings matrice: + [0, ---------------------- + ... -> word embeddings + config.vocab_size - 1, ______________________ + config.vocab_size, + ... -> special embeddings + config.vocab_size + config.n_special - 1, ______________________ + config.vocab_size + config.n_special, + ... -> position embeddings + total_num_embeddings - 1] ______________________ + + where total_num_embeddings can be obtained as config.total_num_embeddings and is: + total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx + You should use the associate indices to index the embeddings. + + The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them. + The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function. + + Params: + config: a TransfoXLConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] + were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[ + `position_ids`: an optional torch.LongTensor with the same shape as input_ids + with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[. + `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids + You can use it to add a third embedding (the previous two being the word and position embeddings) + to each token in the sentence. + + Outputs: + `hidden_states`: the encoded-hidden-states at the top of the model + as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] + (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids) + + Example usage: + ```python + # Already been converted into BPE token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + + config = modeling_transfo_xl.TransfoXLConfig() + + model = modeling_transfo_xl.TransfoXLModel(config) + hidden_states = model(input_ids) + ``` + """ + def __init__(self, config): + super(TransfoXLModel, self).__init__(config) + total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx + self.embed = nn.Embedding(total_embeddings_size, config.n_embd) + self.drop = nn.Dropout(config.embd_pdrop) + block = Block(config.n_ctx, config, scale=True) + self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) + + self.apply(self.init_weights) + # nn.init.normal_(self.embed.weight, std=0.02) + + def set_num_special_tokens(self, num_special_tokens): + " Update input embeddings with new embedding matrice " + # Update config + self.config.n_special = num_special_tokens + # # Build new embeddings and initialize + old_embed = self.embed + self.embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd) + # Initialize all new embeddings (in particular the special tokens) + self.init_weights(self.embed) + # Copy word and positional embeddings from the previous weights + self.embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :] + self.embed.weight.data[-self.config.n_ctx:, :] = old_embed.weight.data[-self.config.n_ctx:, :] + + def forward(self, input_ids, position_ids=None, token_type_ids=None): + if position_ids is None: + start = self.config.vocab_size + self.config.n_special + end = start + input_ids.size(-1) + position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_ids.size(-1)) + position_ids = position_ids.view(-1, position_ids.size(-1)) + + inputs_embeds = self.embed(input_ids) + position_embeds = self.embed(position_ids) + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) + token_type_embeds = self.embed(token_type_ids) + else: + token_type_embeds = 0 + # Add the position information to the input embeddings + # h = e.sum(dim=2) + hidden_states = inputs_embeds + position_embeds + token_type_embeds + for block in self.h: + hidden_states = block(hidden_states) + return hidden_states.view(*input_shape, hidden_states.size(-1)) + +class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): + """OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training"). + + There are two main implementation differences between BERT and the OpenAI GPT: + - the use of an LM loss in OpenAI GPT which means the Transformer is trained to predict the NEXT token for each input token + vs. predict the SAME token for BERT (i.e. you need to shift your labels to the right) + - the use, in OpenAI GPT, of a single embedding matrix to store the word, special ([SEP], [CLS]...) and position embeddings. + The embeddings are ordered as follow in the word embeddings matrice: + [0, ---------------------- + ... -> word embeddings + config.vocab_size - 1, ______________________ + config.vocab_size, + ... -> special embeddings + config.vocab_size + config.n_special - 1, ______________________ + config.vocab_size + config.n_special, + ... -> position embeddings + total_num_embeddings - 1] ______________________ + + where total_num_embeddings can be obtained as config.total_num_embeddings and is: + total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx + You should use these indices to index the word, special and position embeddings. + + The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them. + The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function. + + Params: + config: a TransfoXLConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] + were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[ + `position_ids`: an optional torch.LongTensor with the same shape as input_ids + with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[. + `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids + You can use it to add a third embedding (the previous two being the word and position embeddings) + to each token in the sentence. + `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + + Outputs: + if `lm_labels` is not `None`: + Outputs the language modeling loss. + else: + `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_num_embeddings] + (or more generally [d_1, ..., d_n, total_num_embeddings] were d_1 ... d_n are the dimension of input_ids) + + Example usage: + ```python + # Already been converted into BPE token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + + config = modeling_transfo_xl.TransfoXLConfig() + + model = modeling_transfo_xl.TransfoXLLMHeadModel(config) + lm_logits = model(input_ids) + ``` + """ + def __init__(self, config): + super(TransfoXLLMHeadModel, self).__init__(config) + self.transformer = TransfoXLModel(config) + self.lm_head = TransfoXLLMHead(self.transformer.embed.weight, config) + self.apply(self.init_weights) + + def set_num_special_tokens(self, num_special_tokens): + " Update input and output embeddings with new embedding matrice " + self.transformer.set_num_special_tokens(num_special_tokens) + self.lm_head.set_embeddings_weights(self.transformer.embed.weight) + + def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None): + hidden_states = self.transformer(input_ids, position_ids, token_type_ids) + lm_logits = self.lm_head(hidden_states) + if lm_labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)) + return loss + return lm_logits + +class TransfoXLDoubleHeadsModel(TransfoXLPreTrainedModel): + """OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training"). + + There are two main implementation differences between BERT and the OpenAI GPT: + - the use of an LM loss in OpenAI GPT which means the Transformer is trained to predict the NEXT token for each input token + vs. predict the SAME token for BERT (i.e. you need to shift your labels to the right) + - the use, in OpenAI GPT, of a single embedding matrix to store the word, special ([SEP], [CLS]...) and position embeddings. + The embeddings are ordered as follow in the word embeddings matrice: + [0, ---------------------- + ... -> word embeddings + config.vocab_size - 1, ______________________ + config.vocab_size, + ... -> special embeddings + config.vocab_size + config.n_special - 1, ______________________ + config.vocab_size + config.n_special, + ... -> position embeddings + total_num_embeddings - 1] ______________________ + + where total_num_embeddings can be obtained as config.total_num_embeddings and is: + total_num_embeddings = config.vocab_size + config.n_special + config.n_ctx + You should use these indices to index the word, special and position embeddings. + + The special embeddings ([SEP], [CLS]...) are not pre-trained and need to be trained during the fine-tuning if you use them. + The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function. + + Params: + config: a TransfoXLConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the word BPE token indices selected in the range [0, config.vocab_size[ + `multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise. + `position_ids`: an optional torch.LongTensor with the same shape as input_ids + with the position indices (selected in the range [config.vocab_size + config.n_special, + config.vocab_size + config.n_special + config.n_ctx - 1[. + `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids + You can use it to add a third embedding (the previous two being the word and position embeddings) + to each token in the sentence. + `lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with indices selected in [-1, 0, ..., total_num_embeddings]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., total_num_embeddings] + `multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_choices]. + + Outputs: + if `lm_labels` and `multiple_choice_labels` are not `None`: + Outputs a tuple of losses with the language modeling loss and the multiple choice loss. + else: a tuple with + `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_num_embeddings] + `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices] + + Example usage: + ```python + # Already been converted into BPE token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + multiple_choice_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling_transfo_xl.TransfoXLConfig() + + model = modeling_transfo_xl.TransfoXLLMHeadModel(config) + lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask) + ``` + """ + def __init__(self, config): + super(TransfoXLDoubleHeadsModel, self).__init__(config) + self.transformer = TransfoXLModel(config) + self.lm_head = TransfoXLLMHead(self.transformer.embed.weight, config) + self.multiple_choice_head = TransfoXLMultipleChoiceHead(config) + self.apply(self.init_weights) + + def set_num_special_tokens(self, num_special_tokens): + " Update input and output embeddings with new embedding matrice " + self.transformer.set_num_special_tokens(num_special_tokens) + self.lm_head.set_embeddings_weights(self.transformer.embed.weight) + + def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None, + lm_labels=None, multiple_choice_labels=None): + hidden_states = self.transformer(input_ids, position_ids, token_type_ids) + lm_logits = self.lm_head(hidden_states) + multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask) + losses = [] + if lm_labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))) + if multiple_choice_labels is not None: + loss_fct = CrossEntropyLoss() + losses.append(loss_fct(multiple_choice_logits, multiple_choice_labels.view(-1))) + if losses: + return losses + return lm_logits, multiple_choice_logits diff --git a/pytorch_pretrained_bert/modeling_transfo_xl_utilities.py b/pytorch_pretrained_bert/modeling_transfo_xl_utilities.py new file mode 100644 index 0000000000..a9ead38faf --- /dev/null +++ b/pytorch_pretrained_bert/modeling_transfo_xl_utilities.py @@ -0,0 +1,314 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Utilities for PyTorch Transformer XL model. + Directly adapted from https://github.com/kimiyoung/transformer-xl. +""" + +from collections import defaultdict + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) +# CUDA_MINOR = int(torch.version.cuda.split('.')[1]) + +class ProjectedAdaptiveLogSoftmax(nn.Module): + def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, + keep_order=False): + super(ProjectedAdaptiveLogSoftmax, self).__init__() + + self.n_token = n_token + self.d_embed = d_embed + self.d_proj = d_proj + + self.cutoffs = cutoffs + [n_token] + self.cutoff_ends = [0] + self.cutoffs + self.div_val = div_val + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + if self.n_clusters > 0: + self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) + self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) + + self.out_layers = nn.ModuleList() + self.out_projs = nn.ParameterList() + + if div_val == 1: + for i in range(len(self.cutoffs)): + if d_proj != d_embed: + self.out_projs.append( + nn.Parameter(torch.Tensor(d_proj, d_embed)) + ) + else: + self.out_projs.append(None) + + self.out_layers.append(nn.Linear(d_embed, n_token)) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] + d_emb_i = d_embed // (div_val ** i) + + self.out_projs.append( + nn.Parameter(torch.Tensor(d_proj, d_emb_i)) + ) + + self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) + + self.keep_order = keep_order + + def _compute_logit(self, hidden, weight, bias, proj): + if proj is None: + logit = F.linear(hidden, weight, bias=bias) + else: + # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: + proj_hid = F.linear(hidden, proj.t().contiguous()) + logit = F.linear(proj_hid, weight, bias=bias) + # else: + # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) + # if bias is not None: + # logit = logit + bias + + return logit + + def forward(self, hidden, target, keep_order=False): + ''' + hidden :: [len*bsz x d_proj] + target :: [len*bsz] + ''' + + if hidden.size(0) != target.size(0): + raise RuntimeError('Input and target should have the same size ' + 'in the batch dimension.') + + if self.n_clusters == 0: + logit = self._compute_logit(hidden, self.out_layers[0].weight, + self.out_layers[0].bias, self.out_projs[0]) + nll = -F.log_softmax(logit, dim=-1) \ + .gather(1, target.unsqueeze(1)).squeeze(1) + else: + # construct weights and biases + weights, biases = [], [] + for i in range(len(self.cutoffs)): + if self.div_val == 1: + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + weight_i = self.out_layers[0].weight[l_idx:r_idx] + bias_i = self.out_layers[0].bias[l_idx:r_idx] + else: + weight_i = self.out_layers[i].weight + bias_i = self.out_layers[i].bias + + if i == 0: + weight_i = torch.cat( + [weight_i, self.cluster_weight], dim=0) + bias_i = torch.cat( + [bias_i, self.cluster_bias], dim=0) + + weights.append(weight_i) + biases.append(bias_i) + + head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] + + head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) + head_logprob = F.log_softmax(head_logit, dim=1) + + nll = torch.zeros_like(target, + dtype=hidden.dtype, device=hidden.device) + + offset = 0 + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] + + mask_i = (target >= l_idx) & (target < r_idx) + indices_i = mask_i.nonzero().squeeze() + + if indices_i.numel() == 0: + continue + + target_i = target.index_select(0, indices_i) - l_idx + head_logprob_i = head_logprob.index_select(0, indices_i) + + if i == 0: + logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) + else: + weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] + + hidden_i = hidden.index_select(0, indices_i) + + tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) + tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) + + logprob_i = head_logprob_i[:, -i] \ + + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) + + if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: + nll.index_copy_(0, indices_i, -logprob_i) + else: + nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) + + offset += logprob_i.size(0) + + return nll + +class LogUniformSampler(object): + def __init__(self, range_max, n_sample): + """ + Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py + `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` + + expected count can be approximated by 1 - (1 - p)^n + and we use a numerically stable version -expm1(num_tries * log1p(-p)) + + Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run + """ + with torch.no_grad(): + self.range_max = range_max + log_indices = torch.arange(1., range_max+2., 1.).log_() + self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] + # print('P', self.dist.numpy().tolist()[-30:]) + + self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() + + self.n_sample = n_sample + + def sample(self, labels): + """ + labels: [b1, b2] + Return + true_log_probs: [b1, b2] + samp_log_probs: [n_sample] + neg_samples: [n_sample] + """ + + # neg_samples = torch.empty(0).long() + n_sample = self.n_sample + n_tries = 2 * n_sample + + with torch.no_grad(): + neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() + device = labels.device + neg_samples = neg_samples.to(device) + true_log_probs = self.log_q[labels].to(device) + samp_log_probs = self.log_q[neg_samples].to(device) + return true_log_probs, samp_log_probs, neg_samples + +def sample_logits(embedding, bias, labels, inputs, sampler): + """ + embedding: an nn.Embedding layer + bias: [n_vocab] + labels: [b1, b2] + inputs: [b1, b2, n_emb] + sampler: you may use a LogUniformSampler + Return + logits: [b1, b2, 1 + n_sample] + """ + true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) + n_sample = neg_samples.size(0) + b1, b2 = labels.size(0), labels.size(1) + all_ids = torch.cat([labels.view(-1), neg_samples]) + all_w = embedding(all_ids) + true_w = all_w[: -n_sample].view(b1, b2, -1) + sample_w = all_w[- n_sample:].view(n_sample, -1) + + all_b = bias[all_ids] + true_b = all_b[: -n_sample].view(b1, b2) + sample_b = all_b[- n_sample:] + + hit = (labels[:, :, None] == neg_samples).detach() + + true_logits = torch.einsum('ijk,ijk->ij', + [true_w, inputs]) + true_b - true_log_probs + sample_logits = torch.einsum('lk,ijk->ijl', + [sample_w, inputs]) + sample_b - samp_log_probs + sample_logits.masked_fill_(hit, -1e30) + logits = torch.cat([true_logits[:, :, None], sample_logits], -1) + + return logits + + +# class LogUniformSampler(object): +# def __init__(self, range_max, unique=False): +# """ +# Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py +# `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` +# """ +# self.range_max = range_max +# log_indices = torch.arange(1., range_max+2., 1.).log_() +# self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] + +# self.unique = unique + +# if self.unique: +# self.exclude_mask = torch.ByteTensor(range_max).fill_(0) + +# def sample(self, n_sample, labels): +# pos_sample, new_labels = labels.unique(return_inverse=True) +# n_pos_sample = pos_sample.size(0) +# n_neg_sample = n_sample - n_pos_sample + +# if self.unique: +# self.exclude_mask.index_fill_(0, pos_sample, 1) +# sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) +# self.exclude_mask.index_fill_(0, pos_sample, 0) +# else: +# sample_dist = self.dist + +# neg_sample = torch.multinomial(sample_dist, n_neg_sample) + +# sample = torch.cat([pos_sample, neg_sample]) +# sample_prob = self.dist[sample] + +# return new_labels, sample, sample_prob + + +if __name__ == '__main__': + S, B = 3, 4 + n_vocab = 10000 + n_sample = 5 + H = 32 + + labels = torch.LongTensor(S, B).random_(0, n_vocab) + + # sampler = LogUniformSampler(n_vocab, unique=False) + # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) + + sampler = LogUniformSampler(n_vocab, unique=True) + # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) + + # print('true_probs', true_probs.numpy().tolist()) + # print('samp_probs', samp_probs.numpy().tolist()) + # print('neg_samples', neg_samples.numpy().tolist()) + + # print('sum', torch.sum(sampler.dist).item()) + + # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() + + embedding = nn.Embedding(n_vocab, H) + bias = torch.zeros(n_vocab) + inputs = torch.Tensor(S, B, H).normal_() + + logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) + print('logits', logits.detach().numpy().tolist()) + print('logits shape', logits.size()) + print('out_labels', out_labels.detach().numpy().tolist()) + print('out_labels shape', out_labels.size()) + diff --git a/pytorch_pretrained_bert/tokenization_transfo_xl.py b/pytorch_pretrained_bert/tokenization_transfo_xl.py new file mode 100644 index 0000000000..1d278abcb2 --- /dev/null +++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py @@ -0,0 +1,508 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization classes for Transformer XL model. + Directly adapted from https://github.com/kimiyoung/transformer-xl. +""" + +import os +import re +import json +from tqdm import tqdm +import logging +import pickle +from collections import Counter, OrderedDict + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", +} +PRETRAINED_MERGES_ARCHIVE_MAP = { + 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'openai-gpt': 512, +} +VOCAB_NAME = 'vocab.json' +MERGES_NAME = 'merges.txt' + +class TransfoXLTokenizer(object): + """ + Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl + """ + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a TransfoXLTokenizer. + Download and cache the vocabulary if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) + merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) + except FileNotFoundError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find files {} and {} " + "at this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + pretrained_model_name_or_path, + vocab_file, merges_file)) + return None + if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + logger.info("loading merges file {}".format(merges_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + logger.info("loading merges file {} from cache at {}".format( + merges_file, resolved_merges_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) + return tokenizer + + def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, + delimiter=None, vocab_file=None): + self.counter = Counter() + self.special = special + self.min_freq = min_freq + self.max_size = max_size + self.lower_case = lower_case + self.delimiter = delimiter + self.vocab_file = vocab_file + + def count_file(self, path, verbose=False, add_eos=False): + if verbose: print('counting file {} ...'.format(path)) + assert os.path.exists(path) + + sents = [] + with open(path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + symbols = self.tokenize(line, add_eos=add_eos) + self.counter.update(symbols) + sents.append(symbols) + + return sents + + def count_sents(self, sents, verbose=False): + """ + sents : a list of sentences, each a list of tokenized symbols + """ + if verbose: print('counting {} sents ...'.format(len(sents))) + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + self.counter.update(symbols) + + def _build_from_file(self, vocab_file): + self.idx2sym = [] + self.sym2idx = OrderedDict() + + with open(vocab_file, 'r', encoding='utf-8') as f: + for line in f: + symb = line.strip().split()[0] + self.add_symbol(symb) + self.unk_idx = self.sym2idx[''] + + def build_vocab(self): + if self.vocab_file: + print('building vocab from {}'.format(self.vocab_file)) + self._build_from_file(self.vocab_file) + print('final vocab size {}'.format(len(self))) + else: + print('building vocab with min_freq={}, max_size={}'.format( + self.min_freq, self.max_size)) + self.idx2sym = [] + self.sym2idx = OrderedDict() + + for sym in self.special: + self.add_special(sym) + + for sym, cnt in self.counter.most_common(self.max_size): + if cnt < self.min_freq: break + self.add_symbol(sym) + + print('final vocab size {} from {} unique tokens'.format( + len(self), len(self.counter))) + + def encode_file(self, path, ordered=False, verbose=False, add_eos=True, + add_double_eos=False): + if verbose: print('encoding file {} ...'.format(path)) + assert os.path.exists(path) + encoded = [] + with open(path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + symbols = self.tokenize(line, add_eos=add_eos, + add_double_eos=add_double_eos) + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + def encode_sents(self, sents, ordered=False, verbose=False): + if verbose: print('encoding {} sents ...'.format(len(sents))) + encoded = [] + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + def add_special(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) + + def add_symbol(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + + def get_sym(self, idx): + assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) + return self.idx2sym[idx] + + def get_idx(self, sym): + if sym in self.sym2idx: + return self.sym2idx[sym] + else: + # print('encounter unk {}'.format(sym)) + assert '' not in sym + assert hasattr(self, 'unk_idx') + return self.sym2idx.get(sym, self.unk_idx) + + def convert_ids_to_tokens(self, indices): + """Converts a sequence of indices in symbols using the vocab.""" + return [self.get_sym(idx) for idx in indices] + + def convert_tokens_to_ids(self, symbols): + """Converts a sequence of symbols into ids using the vocab.""" + return [self.get_idx(sym) for sym in symbols] + + def convert_to_tensor(self, symbols): + return torch.LongTensor(self.convert_tokens_to_ids(symbols)) + + def decode(self, indices, exclude=None): + """Converts a sequence of indices in a string.""" + if exclude is None: + return ' '.join([self.get_sym(idx) for idx in indices]) + else: + return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) + + def __len__(self): + return len(self.idx2sym) + + def tokenize(self, line, add_eos=False, add_double_eos=False): + line = line.strip() + # convert to lower case + if self.lower_case: + line = line.lower() + + # empty delimiter '' will evaluate False + if self.delimiter == '': + symbols = line + else: + symbols = line.split(self.delimiter) + + if add_double_eos: # lm1b + return [''] + symbols + [''] + elif add_eos: + return symbols + [''] + else: + return symbols + + +class LMOrderedIterator(object): + def __init__(self, data, bsz, bptt, device='cpu', ext_len=None): + """ + data -- LongTensor -- the LongTensor is strictly ordered + """ + self.bsz = bsz + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + + # Work out how cleanly we can divide the dataset into bsz parts. + self.n_step = data.size(0) // bsz + + # Trim off any extra elements that wouldn't cleanly fit (remainders). + data = data.narrow(0, 0, self.n_step * bsz) + + # Evenly divide the data across the bsz batches. + self.data = data.view(bsz, -1).t().contiguous().to(device) + + # Number of mini-batches + self.n_batch = (self.n_step + self.bptt - 1) // self.bptt + + def get_batch(self, i, bptt=None): + if bptt is None: bptt = self.bptt + seq_len = min(bptt, self.data.size(0) - 1 - i) + + end_idx = i + seq_len + beg_idx = max(0, i - self.ext_len) + + data = self.data[beg_idx:end_idx] + target = self.data[i+1:i+1+seq_len] + + return data, target, seq_len + + def get_fixlen_iter(self, start=0): + for i in range(start, self.data.size(0) - 1, self.bptt): + yield self.get_batch(i) + + def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): + max_len = self.bptt + max_deviation * std + i = start + while True: + bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. + bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) + data, target, seq_len = self.get_batch(i, bptt) + i += seq_len + yield data, target, seq_len + if i >= self.data.size(0) - 2: + break + + def __iter__(self): + return self.get_fixlen_iter() + + +class LMShuffledIterator(object): + def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): + """ + data -- list[LongTensor] -- there is no order among the LongTensors + """ + self.data = data + + self.bsz = bsz + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + self.shuffle = shuffle + + def get_sent_stream(self): + # index iterator + epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ + else np.array(range(len(self.data))) + + # sentence iterator + for idx in epoch_indices: + yield self.data[idx] + + def stream_iterator(self, sent_stream): + # streams for each data in the batch + streams = [None] * self.bsz + + data = torch.LongTensor(self.bptt, self.bsz) + target = torch.LongTensor(self.bptt, self.bsz) + + n_retain = 0 + + while True: + # data : [n_retain+bptt x bsz] + # target : [bptt x bsz] + data[n_retain:].fill_(-1) + target.fill_(-1) + + valid_batch = True + + for i in range(self.bsz): + n_filled = 0 + try: + while n_filled < self.bptt: + if streams[i] is None or len(streams[i]) <= 1: + streams[i] = next(sent_stream) + # number of new tokens to fill in + n_new = min(len(streams[i]) - 1, self.bptt - n_filled) + # first n_retain tokens are retained from last batch + data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ + streams[i][:n_new] + target[n_filled:n_filled+n_new, i] = \ + streams[i][1:n_new+1] + streams[i] = streams[i][n_new:] + n_filled += n_new + except StopIteration: + valid_batch = False + break + + if not valid_batch: + return + + data = data.to(self.device) + target = target.to(self.device) + + yield data, target, self.bptt + + n_retain = min(data.size(0), self.ext_len) + if n_retain > 0: + data[:n_retain] = data[-n_retain:] + data.resize_(n_retain + self.bptt, data.size(1)) + + def __iter__(self): + # sent_stream is an iterator + sent_stream = self.get_sent_stream() + + for batch in self.stream_iterator(sent_stream): + yield batch + + +class LMMultiFileIterator(LMShuffledIterator): + def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None, + shuffle=False): + + self.paths = paths + self.vocab = vocab + + self.bsz = bsz + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + self.shuffle = shuffle + + def get_sent_stream(self, path): + sents = self.vocab.encode_file(path, add_double_eos=True) + if self.shuffle: + np.random.shuffle(sents) + sent_stream = iter(sents) + + return sent_stream + + def __iter__(self): + if self.shuffle: + np.random.shuffle(self.paths) + + for path in self.paths: + # sent_stream is an iterator + sent_stream = self.get_sent_stream(path) + for batch in self.stream_iterator(sent_stream): + yield batch + + +class Corpus(object): + def __init__(self, path, dataset, *args, **kwargs): + self.dataset = dataset + self.vocab = Vocab(*args, **kwargs) + + if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']: + self.vocab.count_file(os.path.join(path, 'train.txt')) + self.vocab.count_file(os.path.join(path, 'valid.txt')) + self.vocab.count_file(os.path.join(path, 'test.txt')) + elif self.dataset == 'wt103': + self.vocab.count_file(os.path.join(path, 'train.txt')) + elif self.dataset == 'lm1b': + train_path_pattern = os.path.join( + path, '1-billion-word-language-modeling-benchmark-r13output', + 'training-monolingual.tokenized.shuffled', 'news.en-*') + train_paths = glob.glob(train_path_pattern) + # the vocab will load from file when build_vocab() is called + + self.vocab.build_vocab() + + if self.dataset in ['ptb', 'wt2', 'wt103']: + self.train = self.vocab.encode_file( + os.path.join(path, 'train.txt'), ordered=True) + self.valid = self.vocab.encode_file( + os.path.join(path, 'valid.txt'), ordered=True) + self.test = self.vocab.encode_file( + os.path.join(path, 'test.txt'), ordered=True) + elif self.dataset in ['enwik8', 'text8']: + self.train = self.vocab.encode_file( + os.path.join(path, 'train.txt'), ordered=True, add_eos=False) + self.valid = self.vocab.encode_file( + os.path.join(path, 'valid.txt'), ordered=True, add_eos=False) + self.test = self.vocab.encode_file( + os.path.join(path, 'test.txt'), ordered=True, add_eos=False) + elif self.dataset == 'lm1b': + self.train = train_paths + self.valid = self.vocab.encode_file( + os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True) + self.test = self.vocab.encode_file( + os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True) + + def get_iterator(self, split, *args, **kwargs): + if split == 'train': + if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: + data_iter = LMOrderedIterator(self.train, *args, **kwargs) + elif self.dataset == 'lm1b': + kwargs['shuffle'] = True + data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) + elif split in ['valid', 'test']: + data = self.valid if split == 'valid' else self.test + if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: + data_iter = LMOrderedIterator(data, *args, **kwargs) + elif self.dataset == 'lm1b': + data_iter = LMShuffledIterator(data, *args, **kwargs) + + return data_iter + + +def get_lm_corpus(datadir, dataset): + fn = os.path.join(datadir, 'cache.pt') + fn_pickle = os.path.join(datadir, 'cache.pkl') + if os.path.exists(fn): + print('Loading cached dataset...') + corpus = torch.load(fn_pickle) + elif os.path.exists(fn): + print('Loading cached dataset from pickle...') + with open(fn, "rb") as fp: + corpus = pickle.load(fp) + else: + print('Producing dataset {}...'.format(dataset)) + kwargs = {} + if dataset in ['wt103', 'wt2']: + kwargs['special'] = [''] + kwargs['lower_case'] = False + elif dataset == 'ptb': + kwargs['special'] = [''] + kwargs['lower_case'] = True + elif dataset == 'lm1b': + kwargs['special'] = [] + kwargs['lower_case'] = False + kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt') + elif dataset in ['enwik8', 'text8']: + pass + + corpus = Corpus(datadir, dataset, **kwargs) + torch.save(corpus, fn) + + return corpus