From 8831c6880390e84494b34fc14f938c8a1c9654eb Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 16 Jan 2019 10:31:16 +0100 Subject: [PATCH] fixing various parts of model conversion, loading and weights sharing --- examples/eval_transfo_xl.py | 2 +- ...onvert_transfo_xl_checkpoint_to_pytorch.py | 5 +- .../modeling_transfo_xl.py | 513 ++++++++---------- .../tokenization_transfo_xl.py | 6 + 4 files changed, 243 insertions(+), 283 deletions(-) diff --git a/examples/eval_transfo_xl.py b/examples/eval_transfo_xl.py index 15c2665782..92979d1e4a 100644 --- a/examples/eval_transfo_xl.py +++ b/examples/eval_transfo_xl.py @@ -42,7 +42,7 @@ parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model # parser.add_argument('--data', type=str, default='../data/wikitext-103', # help='location of the data corpus') parser.add_argument('--model_name', type=str, default='transfo-xl-wt103', - choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'], + # choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'], help='pretrained model name') parser.add_argument('--split', type=str, default='test', choices=['all', 'valid', 'test'], diff --git a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py index 6962481adc..b2f8432d3a 100755 --- a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py @@ -116,7 +116,8 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) - torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path) + corpus_vocab_dict = corpus.vocab.__dict__ + torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) corpus_dict_no_vocab = corpus.__dict__ corpus_dict_no_vocab.pop('vocab', None) @@ -139,7 +140,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, model = TransfoXLModel(config) # Build TF to PyTorch weights loading map - tf_to_pt_map = build_tf_to_pytorch_map(model.transformer, config) + tf_to_pt_map = build_tf_to_pytorch_map(model, config) # Load weights from TF model init_vars = tf.train.list_variables(tf_path) diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index 5b80f045a4..0e1f3f8240 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -30,6 +30,7 @@ import collections import torch import torch.nn as nn +import torch.nn.functional as F from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter @@ -40,7 +41,10 @@ from .file_utils import cached_path logger = logging.getLogger(__name__) PRETRAINED_MODEL_ARCHIVE_MAP = { - 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103.tar.gz", + 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin", +} +PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-transfo_xl_config.json", } CONFIG_NAME = 'transfo_xl_config.json' WEIGHTS_NAME = 'pytorch_model.bin' @@ -674,99 +678,266 @@ class AdaptiveEmbedding(nn.Module): return embed -class MemTransformerLM(nn.Module): - def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner, - dropout, dropatt, tie_weight=True, d_embed=None, - div_val=1, tie_projs=[False], pre_lnorm=False, - tgt_len=None, ext_len=None, mem_len=None, - cutoffs=[], adapt_inp=False, untie_r=False, - same_length=False, attn_type=0, clamp_len=-1, - sample_softmax=-1, **kwargs): - super(MemTransformerLM, self).__init__() - self.n_token = n_token - d_embed = d_model if d_embed is None else d_embed - self.d_embed = d_embed - self.d_model = d_model - self.n_head = n_head - self.d_head = d_head +class TransfoXLPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super(TransfoXLPreTrainedModel, self).__init__() + if not isinstance(config, TransfoXLConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. " + "To create a model from a pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config - self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, - div_val=div_val) + def init_weight(self, weight): + if self.config.init == 'uniform': + nn.init.uniform_(weight, -self.config.init_range, self.config.init_range) + elif self.config.init == 'normal': + nn.init.normal_(weight, 0.0, self.config.init_std) - self.drop = nn.Dropout(dropout) + def init_bias(self, bias): + nn.init.constant_(bias, 0.0) - self.n_layer = n_layer + def init_weights(self, m): + """ Initialize the weights. + """ + classname = m.__class__.__name__ + if classname.find('Linear') != -1: + if hasattr(m, 'weight') and m.weight is not None: + self.init_weight(m.weight) + if hasattr(m, 'bias') and m.bias is not None: + self.init_bias(m.bias) + elif classname.find('AdaptiveEmbedding') != -1: + if hasattr(m, 'emb_projs'): + for i in range(len(m.emb_projs)): + if m.emb_projs[i] is not None: + nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std) + elif classname.find('Embedding') != -1: + if hasattr(m, 'weight'): + self.init_weight(m.weight) + elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: + if hasattr(m, 'cluster_weight') and m.cluster_weight is not None: + self.init_weight(m.cluster_weight) + if hasattr(m, 'cluster_bias') and m.cluster_bias is not None: + self.init_bias(m.cluster_bias) + if hasattr(m, 'out_projs'): + for i in range(len(m.out_projs)): + if m.out_projs[i] is not None: + nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std) + elif classname.find('LayerNorm') != -1: + if hasattr(m, 'weight'): + nn.init.normal_(m.weight, 1.0, self.config.init_std) + if hasattr(m, 'bias') and m.bias is not None: + self.init_bias(m.bias) + elif classname.find('TransformerLM') != -1: + if hasattr(m, 'r_emb'): + self.init_weight(m.r_emb) + if hasattr(m, 'r_w_bias'): + self.init_weight(m.r_w_bias) + if hasattr(m, 'r_r_bias'): + self.init_weight(m.r_r_bias) + if hasattr(m, 'r_bias'): + self.init_bias(m.r_bias) - self.tgt_len = tgt_len - self.mem_len = mem_len - self.ext_len = ext_len - self.max_klen = tgt_len + ext_len + mem_len + def set_num_special_tokens(self, num_special_tokens): + pass - self.attn_type = attn_type + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, + *inputs, **kwargs): + """ + Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. - if not untie_r: + Params: + pretrained_model_name_or_path: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `transfo-xl` + - a path or url to a pretrained model archive containing: + . `transfo_xl_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] + config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + resolved_config_file = cached_path(config_file, cache_dir=cache_dir) + except FileNotFoundError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find files {} and {} " + "at this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), + pretrained_model_name_or_path, + archive_file, config_file)) + return None + if resolved_archive_file == archive_file and resolved_config_file == config_file: + logger.info("loading weights file {}".format(archive_file)) + logger.info("loading configuration file {}".format(config_file)) + else: + logger.info("loading weights file {} from cache at {}".format( + archive_file, resolved_archive_file)) + logger.info("loading configuration file {} from cache at {}".format( + config_file, resolved_config_file)) + # Load config + config = TransfoXLConfig.from_json_file(resolved_config_file) + logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + if state_dict is None: + state_dict = torch.load(resolved_archive_file) + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + # load(model.transformer if hasattr(model, 'transformer') else model, prefix='') + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + return model + + +class TransfoXLModel(TransfoXLPreTrainedModel): + def __init__(self, config): + # n_token, n_layer, n_head, d_model, d_head, d_inner, + # dropout, dropatt, tie_weight=True, d_embed=None, + # div_val=1, tie_projs=[False], pre_lnorm=False, + # tgt_len=None, ext_len=None, mem_len=None, + # cutoffs=[], adapt_inp=False, untie_r=False, + # same_length=False, attn_type=0, clamp_len=-1, + # sample_softmax=-1, **kwargs): + super(TransfoXLModel, self).__init__(config) + self.n_token = config.n_token + + self.d_embed = config.d_embed + self.d_model = config.d_model + self.n_head = config.n_head + self.d_head = config.d_head + + self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs, + div_val=config.div_val) + + self.drop = nn.Dropout(config.dropout) + + self.n_layer = config.n_layer + + self.tgt_len = config.tgt_len + self.mem_len = config.mem_len + self.ext_len = config.ext_len + self.max_klen = config.tgt_len + config.ext_len + config.mem_len + + self.attn_type = config.attn_type + + if not config.untie_r: self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) self.layers = nn.ModuleList() - if attn_type == 0: # the default attention - for i in range(n_layer): + if config.attn_type == 0: # the default attention + for i in range(config.n_layer): self.layers.append( RelPartialLearnableDecoderLayer( - n_head, d_model, d_head, d_inner, dropout, - tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, - dropatt=dropatt, pre_lnorm=pre_lnorm, - r_w_bias=None if untie_r else self.r_w_bias, - r_r_bias=None if untie_r else self.r_r_bias) + config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, + tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len, + dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, + r_w_bias=None if config.untie_r else self.r_w_bias, + r_r_bias=None if config.untie_r else self.r_r_bias) ) - elif attn_type == 1: # learnable embeddings - for i in range(n_layer): + elif config.attn_type == 1: # learnable embeddings + for i in range(config.n_layer): self.layers.append( RelLearnableDecoderLayer( - n_head, d_model, d_head, d_inner, dropout, - tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, - dropatt=dropatt, pre_lnorm=pre_lnorm, - r_w_bias=None if untie_r else self.r_w_bias, - r_r_bias=None if untie_r else self.r_r_bias) + config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, + tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len, + dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, + r_w_bias=None if config.untie_r else self.r_w_bias, + r_r_bias=None if config.untie_r else self.r_r_bias) ) - elif attn_type in [2, 3]: # absolute embeddings - for i in range(n_layer): + elif config.attn_type in [2, 3]: # absolute embeddings + for i in range(config.n_layer): self.layers.append( DecoderLayer( - n_head, d_model, d_head, d_inner, dropout, - dropatt=dropatt, pre_lnorm=pre_lnorm, - r_w_bias=None if untie_r else self.r_w_bias, - r_r_bias=None if untie_r else self.r_r_bias) + config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, + dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, + r_w_bias=None if config.untie_r else self.r_w_bias, + r_r_bias=None if config.untie_r else self.r_r_bias) ) - self.sample_softmax = sample_softmax + self.sample_softmax = config.sample_softmax # use sampled softmax - if sample_softmax > 0: - self.out_layer = nn.Linear(d_model, n_token) - if tie_weight: + if config.sample_softmax > 0: + self.out_layer = nn.Linear(config.d_model, config.n_token) + if config.tie_weight: self.out_layer.weight = self.word_emb.weight - self.tie_weight = tie_weight - self.sampler = LogUniformSampler(n_token, sample_softmax) + self.tie_weight = config.tie_weight + self.sampler = LogUniformSampler(config.n_token, config.sample_softmax) # use adaptive softmax (including standard softmax) else: - self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, - cutoffs, div_val=div_val) + self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, + config.cutoffs, div_val=config.div_val) - if tie_weight: + if config.tie_weight: for i in range(len(self.crit.out_layers)): self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight - if tie_projs: - for i, tie_proj in enumerate(tie_projs): - if tie_proj and div_val == 1 and d_model != d_embed: + if config.tie_projs: + for i, tie_proj in enumerate(config.tie_projs): + if tie_proj and config.div_val == 1 and config.d_model != config.d_embed: self.crit.out_projs[i] = self.word_emb.emb_projs[0] - elif tie_proj and div_val != 1: + elif tie_proj and config.div_val != 1: self.crit.out_projs[i] = self.word_emb.emb_projs[i] - self.same_length = same_length - self.clamp_len = clamp_len + self.same_length = config.same_length + self.clamp_len = config.clamp_len if self.attn_type == 0: # default attention self.pos_emb = PositionalEmbedding(self.d_model) @@ -859,8 +1030,7 @@ class MemTransformerLM(nn.Module): hids.append(core_out) for i, layer in enumerate(self.layers): mems_i = None if mems is None else mems[i] - core_out = layer(core_out, pos_emb, self.r_w_bias, - self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) + core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i) hids.append(core_out) elif self.attn_type == 1: # learnable core_out = self.drop(word_emb) @@ -949,220 +1119,3 @@ class MemTransformerLM(nn.Module): else: return [loss] + new_mems - -class TransfoXLPreTrainedModel(nn.Module): - """ An abstract class to handle weights initialization and - a simple interface for dowloading and loading pretrained models. - """ - def __init__(self, config, *inputs, **kwargs): - super(TransfoXLPreTrainedModel, self).__init__() - if not isinstance(config, TransfoXLConfig): - raise ValueError( - "Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. " - "To create a model from a pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, self.__class__.__name__ - )) - self.config = config - - def init_weight(self, weight): - if self.config.init == 'uniform': - nn.init.uniform_(weight, -self.config.init_range, self.config.init_range) - elif self.config.init == 'normal': - nn.init.normal_(weight, 0.0, self.config.init_std) - - def init_bias(self, bias): - nn.init.constant_(bias, 0.0) - - def init_weights(self, m): - """ Initialize the weights. - """ - classname = m.__class__.__name__ - if classname.find('Linear') != -1: - if hasattr(m, 'weight') and m.weight is not None: - self.init_weight(m.weight) - if hasattr(m, 'bias') and m.bias is not None: - self.init_bias(m.bias) - elif classname.find('AdaptiveEmbedding') != -1: - if hasattr(m, 'emb_projs'): - for i in range(len(m.emb_projs)): - if m.emb_projs[i] is not None: - nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std) - elif classname.find('Embedding') != -1: - if hasattr(m, 'weight'): - self.init_weight(m.weight) - elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: - if hasattr(m, 'cluster_weight') and m.cluster_weight is not None: - self.init_weight(m.cluster_weight) - if hasattr(m, 'cluster_bias') and m.cluster_bias is not None: - self.init_bias(m.cluster_bias) - if hasattr(m, 'out_projs'): - for i in range(len(m.out_projs)): - if m.out_projs[i] is not None: - nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std) - elif classname.find('LayerNorm') != -1: - if hasattr(m, 'weight'): - nn.init.normal_(m.weight, 1.0, self.config.init_std) - if hasattr(m, 'bias') and m.bias is not None: - self.init_bias(m.bias) - elif classname.find('TransformerLM') != -1: - if hasattr(m, 'r_emb'): - self.init_weight(m.r_emb) - if hasattr(m, 'r_w_bias'): - self.init_weight(m.r_w_bias) - if hasattr(m, 'r_r_bias'): - self.init_weight(m.r_r_bias) - if hasattr(m, 'r_bias'): - self.init_bias(m.r_bias) - - def set_num_special_tokens(self, num_special_tokens): - pass - - @classmethod - def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, - *inputs, **kwargs): - """ - Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict. - Download and cache the pre-trained model file if needed. - - Params: - pretrained_model_name: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `transfo-xl` - - a path or url to a pretrained model archive containing: - . `transfo_xl_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance - cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) - """ - if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: - archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] - else: - archive_file = pretrained_model_name - # redirect to the cache, if necessary - try: - resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) - except FileNotFoundError: - logger.error( - "Model name '{}' was not found in model name list ({}). " - "We assumed '{}' was a path or url but couldn't find any file " - "associated to this path or url.".format( - pretrained_model_name, - ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), - archive_file)) - return None - if resolved_archive_file == archive_file: - logger.info("loading archive file {}".format(archive_file)) - else: - logger.info("loading archive file {} from cache at {}".format( - archive_file, resolved_archive_file)) - tempdir = None - if os.path.isdir(resolved_archive_file): - serialization_dir = resolved_archive_file - else: - # Extract archive to temp dir - tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format( - resolved_archive_file, tempdir)) - with tarfile.open(resolved_archive_file, 'r:gz') as archive: - archive.extractall(tempdir) - serialization_dir = tempdir - # Load config - config_file = os.path.join(serialization_dir, CONFIG_NAME) - config = TransfoXLConfig.from_json_file(config_file) - logger.info("Model config {}".format(config)) - # Instantiate model. - model = cls(config, *inputs, **kwargs) - if state_dict is None: - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path) - - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - def load(module, prefix=''): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict( - state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + '.') - # load(model.transformer if hasattr(model, 'transformer') else model, prefix='') - if len(missing_keys) > 0: - logger.info("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, missing_keys)) - if len(unexpected_keys) > 0: - logger.info("Weights from pretrained model not used in {}: {}".format( - model.__class__.__name__, unexpected_keys)) - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) - return model - - -class TransfoXLModel(TransfoXLPreTrainedModel): - """ Transformer XL model - From "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - by Zihang Dai*, Zhilin Yang*, Yiming Yang, William W. Cohen, Jaime Carbonell, - Quoc V. Le, Ruslan Salakhutdinov (*: equal contribution) - - Params: - config: a TransfoXLConfig class instance with the configuration to build a new model - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] - were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[ - `position_ids`: an optional torch.LongTensor with the same shape as input_ids - with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[. - `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids - You can use it to add a third embedding (the previous two being the word and position embeddings) - to each token in the sentence. - - Outputs: - `hidden_states`: the encoded-hidden-states at the top of the model - as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] - (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids) - - Example usage: - ```python - # Already been converted into BPE token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - - config = modeling_transfo_xl.TransfoXLConfig() - - model = modeling_transfo_xl.TransfoXLModel(config) - hidden_states = model(input_ids) - ``` - """ - def __init__(self, config): - super(TransfoXLModel, self).__init__(config) - self.transformer = MemTransformerLM(**config.to_dict()) - self.apply(self.init_weights) - - def forward(self, input_ids, position_ids=None, token_type_ids=None): - return self.transformer(input_ids, position_ids, token_type_ids) diff --git a/pytorch_pretrained_bert/tokenization_transfo_xl.py b/pytorch_pretrained_bert/tokenization_transfo_xl.py index a411c267b9..db626f7755 100644 --- a/pytorch_pretrained_bert/tokenization_transfo_xl.py +++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py @@ -444,6 +444,12 @@ class TransfoXLCorpus(object): for key, value in corpus_dict.items(): corpus.__dict__[key] = value corpus.vocab = vocab + if corpus.train is not None: + corpus.train = torch.tensor(corpus.train, dtype=torch.long) + if corpus.valid is not None: + corpus.valid = torch.tensor(corpus.valid, dtype=torch.long) + if corpus.test is not None: + corpus.test = torch.tensor(corpus.test, dtype=torch.long) return corpus def __init__(self, *args, **kwargs):