From c306869ea2cfeebadd64779408ef7a28132779c9 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 7 Feb 2019 17:07:03 +0100 Subject: [PATCH] add two transformer xl models --- pytorch_pretrained_bert/__init__.py | 2 +- ...onvert_transfo_xl_checkpoint_to_pytorch.py | 10 +- .../modeling_transfo_xl.py | 226 +++++++++++++----- 3 files changed, 174 insertions(+), 64 deletions(-) diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index e4b9c1a116..761af86b6d 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -11,7 +11,7 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining, from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, load_tf_weights_in_openai_gpt) -from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, +from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl) from .optimization import BertAdam diff --git a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py index dedea33435..dae1248f71 100755 --- a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py @@ -27,7 +27,7 @@ import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, WEIGHTS_NAME, TransfoXLConfig, - TransfoXLModel, + TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl) from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_NAME) @@ -37,7 +37,7 @@ if sys.version_info[0] == 2: else: import pickle -# We do this to be able to load the python 2 datasets pickles +# We do this to be able to load python 2 datasets pickles # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 data_utils.Vocab = data_utils.TransfoXLTokenizer data_utils.Corpus = data_utils.TransfoXLCorpus @@ -49,6 +49,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, pytorch_dump_folder_path, transfo_xl_dataset_file): if transfo_xl_dataset_file: + # Convert a pre-processed corpus (see original TensorFlow repo) with open(transfo_xl_dataset_file, "rb") as fp: corpus = pickle.load(fp, encoding="latin1") # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) @@ -64,18 +65,18 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) if tf_checkpoint_path: + # Convert a pre-trained TensorFlow model config_path = os.path.abspath(transfo_xl_config_file) tf_path = os.path.abspath(tf_checkpoint_path) print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) # Initialise PyTorch model - # Construct model if transfo_xl_config_file == "": config = TransfoXLConfig() else: config = TransfoXLConfig(transfo_xl_config_file) print("Building PyTorch model from configuration: {}".format(str(config))) - model = TransfoXLModel(config) + model = TransfoXLLMHeadModel(config) model = load_tf_weights_in_transfo_xl(model, config, tf_path) # Save pytorch-model @@ -90,7 +91,6 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, if __name__ == "__main__": parser = argparse.ArgumentParser() - ## Required parameters parser.add_argument("--pytorch_dump_folder_path", default = None, type = str, diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index 53ebca6e92..f3a3eb46fe 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -57,7 +57,7 @@ def build_tf_to_pytorch_map(model, config): This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible. """ tf_to_pt_map = {} - # Embeddings cutoffs + # Embeddings for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)): layer_str = "transformer/adaptive_embed/cutoff_%d/" % i tf_to_pt_map.update({ @@ -934,11 +934,11 @@ class TransfoXLPreTrainedModel(nn.Module): # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None and not from_tf: - state_dict = torch.load(resolved_archive_file) + state_dict = torch.load(resolved_archive_file, map_location='cpu' if not torch.cuda.is_available() else None) if from_tf: # Directly load from a TensorFlow checkpoint - weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) - return load_tf_weights_in_transfo_xl(model, weights_path) + return load_tf_weights_in_transfo_xl(model, config, pretrained_model_name_or_path) + missing_keys = [] unexpected_keys = [] error_msgs = [] @@ -965,18 +965,49 @@ class TransfoXLPreTrainedModel(nn.Module): if len(error_msgs) > 0: raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( model.__class__.__name__, "\n\t".join(error_msgs))) + # Make sure we are still sharing the input and output embeddings + if model.hasattr('tie_weights'): + model.tie_weights() return model class TransfoXLModel(TransfoXLPreTrainedModel): + """Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"). + + Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that: + - you don't need to specify positioning embeddings indices + - the tokens in the vocabulary have to be sorted to decreasing frequency. + + Params: + config: a TransfoXLConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [sequence_length, batch_size] + with the token indices selected in the range [0, self.config.n_token[ + + Outputs: + A tuple of (last_hidden_state, new_mems) + `last_hidden_state`: the encoded-hidden-states at the top of the model + as a torch.FloatTensor of size [sequence_length, batch_size, self.config.d_model] + `new_mems`: list (num layers) of updated mem states at the entry of each layer + each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model] + + Example usage: + ```python + # Already been converted into BPE token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]]) + + config = TransfoXLConfig() + + model = TransfoXLModel(config) + last_hidden_state, new_mems = model(input_ids) + + # Another time on input_ids_next using the memory: + last_hidden_state, new_mems = model(input_ids_next, new_mems) + ``` + """ def __init__(self, config): - # n_token, n_layer, n_head, d_model, d_head, d_inner, - # dropout, dropatt, tie_weight=True, d_embed=None, - # div_val=1, tie_projs=[False], pre_lnorm=False, - # tgt_len=None, ext_len=None, mem_len=None, - # cutoffs=[], adapt_inp=False, untie_r=False, - # same_length=False, attn_type=0, clamp_len=-1, - # sample_softmax=-1, **kwargs): super(TransfoXLModel, self).__init__(config) self.n_token = config.n_token @@ -1034,31 +1065,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel): r_r_bias=None if config.untie_r else self.r_r_bias) ) - self.sample_softmax = config.sample_softmax - # use sampled softmax - if config.sample_softmax > 0: - self.out_layer = nn.Linear(config.d_model, config.n_token) - if config.tie_weight: - self.out_layer.weight = self.word_emb.weight - self.tie_weight = config.tie_weight - self.sampler = LogUniformSampler(config.n_token, config.sample_softmax) - - # use adaptive softmax (including standard softmax) - else: - self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, - config.cutoffs, div_val=config.div_val) - - if config.tie_weight: - for i in range(len(self.crit.out_layers)): - self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight - - if config.tie_projs: - for i, tie_proj in enumerate(config.tie_projs): - if tie_proj and config.div_val == 1 and config.d_model != config.d_embed: - self.crit.out_projs[i] = self.word_emb.emb_projs[0] - elif tie_proj and config.div_val != 1: - self.crit.out_projs[i] = self.word_emb.emb_projs[i] - self.same_length = config.same_length self.clamp_len = config.clamp_len @@ -1074,6 +1080,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): elif self.attn_type == 3: # absolute deeper SA self.r_emb = nn.Parameter(torch.Tensor( self.n_layer, self.max_klen, self.n_head, self.d_head)) + self.apply(self.init_weights) def backward_compatible(self): self.sample_softmax = -1 @@ -1210,32 +1217,135 @@ class TransfoXLModel(TransfoXLPreTrainedModel): return core_out, new_mems - def forward(self, data, target=None, *mems): - # nn.DataParallel does not allow size(0) tensors to be broadcasted. - # So, have to initialize size(0) mems inside the model forward. - # Moreover, have to return new_mems to allow nn.DataParallel to piece - # them together. - if not mems: - mems = self.init_mems(data) + def forward(self, input_ids, mems=None): + """ Params: + input_ids :: [len, bsz] + Returns: + tuple (last_hidden, new_mems) where: + new_mems: list (num layers) of mem states at the entry of each layer + shape :: [self.config.mem_len, bsz, self.config.d_model] + last_hidden: output of the last layer: + shape :: [len, bsz, self.config.d_model] + """ + if mems is None: + mems = self.init_mems(input_ids) + last_hidden, new_mems = self._forward(input_ids, mems=mems) + return (last_hidden, new_mems) - hidden, new_mems = self._forward(data, mems=mems) - if target is None: - if new_mems is None: - return [hidden] + +class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): + """Transformer XL model ("Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"). + + This model add an (adaptive) softmax head on top of the TransfoXLModel + + Transformer XL use a relative positioning (with sinusiodal patterns) and adaptive softmax inputs which means that: + - you don't need to specify positioning embeddings indices + - the tokens in the vocabulary have to be sorted to decreasing frequency. + + Call self.tie_weights() if you update/load the weights of the transformer to keep the weights tied. + + Params: + config: a TransfoXLConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [sequence_length, batch_size] + with the token indices selected in the range [0, self.config.n_token[ + `target`: a torch.LongTensor of shape [sequence_length, batch_size] + with the target token indices selected in the range [0, self.config.n_token[ + + Outputs: + A tuple of (last_hidden_state, new_mems) + `softmax_output`: output of the (adaptive) softmax: + if target is None: + Negative log likelihood of shape :: [len, bsz] else: - return [hidden] + new_mems + log probabilities of tokens, shape :: [len, bsz, n_tokens] + `new_mems`: list (num layers) of updated mem states at the entry of each layer + each mem state is a torch.FloatTensor of size [self.config.mem_len, batch_size, self.config.d_model] - tgt_len = target.size(0) - pred_hid = hidden[-tgt_len:] + Example usage: + ```python + # Already been converted into BPE token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_ids_next = torch.LongTensor([[53, 21, 1], [64, 23, 100]]) + + config = TransfoXLConfig() + + model = TransfoXLModel(config) + last_hidden_state, new_mems = model(input_ids) + + # Another time on input_ids_next using the memory: + last_hidden_state, new_mems = model(input_ids_next, new_mems) + ``` + """ + def __init__(self, config): + super(TransfoXLLMHeadModel, self).__init__(config) + self.transformer = TransfoXLModel(config) + self.sample_softmax = config.sample_softmax + # use sampled softmax + if config.sample_softmax > 0: + self.out_layer = nn.Linear(config.d_model, config.n_token) + self.sampler = LogUniformSampler(config.n_token, config.sample_softmax) + # use adaptive softmax (including standard softmax) + else: + self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, + config.cutoffs, div_val=config.div_val) + self.apply(self.init_weights) + self.tie_weights() + + def tie_weights(self): + """ Run this to be sure output and input (adaptive) softmax weights are tied """ + # sampled softmax + if self.sample_softmax > 0: + if self.config.tie_weight: + self.out_layer.weight = self.transformer.word_emb.weight + # adaptive softmax (including standard softmax) + else: + if self.config.tie_weight: + for i in range(len(self.crit.out_layers)): + self.crit.out_layers[i].weight = self.transformer.word_emb.emb_layers[i].weight + if self.config.tie_projs: + for i, tie_proj in enumerate(self.config.tie_projs): + if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed: + self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0] + elif tie_proj and self.config.div_val != 1: + self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i] + + def reset_length(self, tgt_len, ext_len, mem_len): + self.transformer.reset_length(tgt_len, ext_len, mem_len) + + def init_mems(self, data): + return self.transformer.init_mems(data) + + def forward(self, input_ids, target=None, mems=None): + """ Params: + input_ids :: [len, bsz] + target :: [len, bsz] + Returns: + tuple(softmax_output, new_mems) where: + new_mems: list (num layers) of hidden states at the entry of each layer + shape :: [mem_len, bsz, self.config.d_model] + softmax_output: output of the (adaptive) softmax: + if target is None: + Negative log likelihood of shape :: [len, bsz] + else: + log probabilities of tokens, shape :: [len, bsz, n_tokens] + """ + bsz = input_ids.size(1) + tgt_len = input_ids.size(0) + + last_hidden, new_mems = self.transformer(input_ids, mems) + + pred_hid = last_hidden[-tgt_len:] if self.sample_softmax > 0 and self.training: - assert self.tie_weight - logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler) + assert self.config.tie_weight + logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, target, pred_hid, self.sampler) loss = -F.log_softmax(logit, -1)[:, :, 0] else: - loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1)) - loss = loss.view(tgt_len, -1) + softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target) + if target is None: + softmax_output = softmax_output.view(tgt_len, bsz, -1) + else: + softmax_output = softmax_output.view(tgt_len, bsz) - if new_mems is None: - return [loss] - else: - return (loss, new_mems) + return (softmax_output, new_mems)