|
|
|
|
@@ -30,6 +30,7 @@ import collections
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
from torch.nn import CrossEntropyLoss
|
|
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
|
|
|
|
@@ -40,7 +41,10 @@ from .file_utils import cached_path
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|
|
|
|
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103.tar.gz",
|
|
|
|
|
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
|
|
|
|
|
}
|
|
|
|
|
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|
|
|
|
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-transfo_xl_config.json",
|
|
|
|
|
}
|
|
|
|
|
CONFIG_NAME = 'transfo_xl_config.json'
|
|
|
|
|
WEIGHTS_NAME = 'pytorch_model.bin'
|
|
|
|
|
@@ -674,99 +678,266 @@ class AdaptiveEmbedding(nn.Module):
|
|
|
|
|
|
|
|
|
|
return embed
|
|
|
|
|
|
|
|
|
|
class MemTransformerLM(nn.Module):
|
|
|
|
|
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
|
|
|
|
|
dropout, dropatt, tie_weight=True, d_embed=None,
|
|
|
|
|
div_val=1, tie_projs=[False], pre_lnorm=False,
|
|
|
|
|
tgt_len=None, ext_len=None, mem_len=None,
|
|
|
|
|
cutoffs=[], adapt_inp=False, untie_r=False,
|
|
|
|
|
same_length=False, attn_type=0, clamp_len=-1,
|
|
|
|
|
sample_softmax=-1, **kwargs):
|
|
|
|
|
super(MemTransformerLM, self).__init__()
|
|
|
|
|
self.n_token = n_token
|
|
|
|
|
|
|
|
|
|
d_embed = d_model if d_embed is None else d_embed
|
|
|
|
|
self.d_embed = d_embed
|
|
|
|
|
self.d_model = d_model
|
|
|
|
|
self.n_head = n_head
|
|
|
|
|
self.d_head = d_head
|
|
|
|
|
class TransfoXLPreTrainedModel(nn.Module):
|
|
|
|
|
""" An abstract class to handle weights initialization and
|
|
|
|
|
a simple interface for dowloading and loading pretrained models.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, config, *inputs, **kwargs):
|
|
|
|
|
super(TransfoXLPreTrainedModel, self).__init__()
|
|
|
|
|
if not isinstance(config, TransfoXLConfig):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
|
|
|
|
|
"To create a model from a pretrained model use "
|
|
|
|
|
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
|
|
|
|
self.__class__.__name__, self.__class__.__name__
|
|
|
|
|
))
|
|
|
|
|
self.config = config
|
|
|
|
|
|
|
|
|
|
self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
|
|
|
|
|
div_val=div_val)
|
|
|
|
|
def init_weight(self, weight):
|
|
|
|
|
if self.config.init == 'uniform':
|
|
|
|
|
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
|
|
|
|
|
elif self.config.init == 'normal':
|
|
|
|
|
nn.init.normal_(weight, 0.0, self.config.init_std)
|
|
|
|
|
|
|
|
|
|
self.drop = nn.Dropout(dropout)
|
|
|
|
|
def init_bias(self, bias):
|
|
|
|
|
nn.init.constant_(bias, 0.0)
|
|
|
|
|
|
|
|
|
|
self.n_layer = n_layer
|
|
|
|
|
def init_weights(self, m):
|
|
|
|
|
""" Initialize the weights.
|
|
|
|
|
"""
|
|
|
|
|
classname = m.__class__.__name__
|
|
|
|
|
if classname.find('Linear') != -1:
|
|
|
|
|
if hasattr(m, 'weight') and m.weight is not None:
|
|
|
|
|
self.init_weight(m.weight)
|
|
|
|
|
if hasattr(m, 'bias') and m.bias is not None:
|
|
|
|
|
self.init_bias(m.bias)
|
|
|
|
|
elif classname.find('AdaptiveEmbedding') != -1:
|
|
|
|
|
if hasattr(m, 'emb_projs'):
|
|
|
|
|
for i in range(len(m.emb_projs)):
|
|
|
|
|
if m.emb_projs[i] is not None:
|
|
|
|
|
nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
|
|
|
|
|
elif classname.find('Embedding') != -1:
|
|
|
|
|
if hasattr(m, 'weight'):
|
|
|
|
|
self.init_weight(m.weight)
|
|
|
|
|
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
|
|
|
|
|
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
|
|
|
|
|
self.init_weight(m.cluster_weight)
|
|
|
|
|
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
|
|
|
|
|
self.init_bias(m.cluster_bias)
|
|
|
|
|
if hasattr(m, 'out_projs'):
|
|
|
|
|
for i in range(len(m.out_projs)):
|
|
|
|
|
if m.out_projs[i] is not None:
|
|
|
|
|
nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
|
|
|
|
|
elif classname.find('LayerNorm') != -1:
|
|
|
|
|
if hasattr(m, 'weight'):
|
|
|
|
|
nn.init.normal_(m.weight, 1.0, self.config.init_std)
|
|
|
|
|
if hasattr(m, 'bias') and m.bias is not None:
|
|
|
|
|
self.init_bias(m.bias)
|
|
|
|
|
elif classname.find('TransformerLM') != -1:
|
|
|
|
|
if hasattr(m, 'r_emb'):
|
|
|
|
|
self.init_weight(m.r_emb)
|
|
|
|
|
if hasattr(m, 'r_w_bias'):
|
|
|
|
|
self.init_weight(m.r_w_bias)
|
|
|
|
|
if hasattr(m, 'r_r_bias'):
|
|
|
|
|
self.init_weight(m.r_r_bias)
|
|
|
|
|
if hasattr(m, 'r_bias'):
|
|
|
|
|
self.init_bias(m.r_bias)
|
|
|
|
|
|
|
|
|
|
self.tgt_len = tgt_len
|
|
|
|
|
self.mem_len = mem_len
|
|
|
|
|
self.ext_len = ext_len
|
|
|
|
|
self.max_klen = tgt_len + ext_len + mem_len
|
|
|
|
|
def set_num_special_tokens(self, num_special_tokens):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
self.attn_type = attn_type
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
|
|
|
|
*inputs, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
|
|
|
|
Download and cache the pre-trained model file if needed.
|
|
|
|
|
|
|
|
|
|
if not untie_r:
|
|
|
|
|
Params:
|
|
|
|
|
pretrained_model_name_or_path: either:
|
|
|
|
|
- a str with the name of a pre-trained model to load selected in the list of:
|
|
|
|
|
. `transfo-xl`
|
|
|
|
|
- a path or url to a pretrained model archive containing:
|
|
|
|
|
. `transfo_xl_config.json` a configuration file for the model
|
|
|
|
|
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
|
|
|
|
|
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
|
|
|
|
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
|
|
|
|
*inputs, **kwargs: additional input for the specific Bert class
|
|
|
|
|
(ex: num_labels for BertForSequenceClassification)
|
|
|
|
|
"""
|
|
|
|
|
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
|
|
|
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
|
|
|
|
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
|
|
|
|
else:
|
|
|
|
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
|
|
|
|
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
|
|
|
|
# redirect to the cache, if necessary
|
|
|
|
|
try:
|
|
|
|
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
|
|
|
|
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
|
|
|
|
except FileNotFoundError:
|
|
|
|
|
logger.error(
|
|
|
|
|
"Model name '{}' was not found in model name list ({}). "
|
|
|
|
|
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
|
|
|
|
"at this path or url.".format(
|
|
|
|
|
pretrained_model_name_or_path,
|
|
|
|
|
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
|
|
|
|
pretrained_model_name_or_path,
|
|
|
|
|
archive_file, config_file))
|
|
|
|
|
return None
|
|
|
|
|
if resolved_archive_file == archive_file and resolved_config_file == config_file:
|
|
|
|
|
logger.info("loading weights file {}".format(archive_file))
|
|
|
|
|
logger.info("loading configuration file {}".format(config_file))
|
|
|
|
|
else:
|
|
|
|
|
logger.info("loading weights file {} from cache at {}".format(
|
|
|
|
|
archive_file, resolved_archive_file))
|
|
|
|
|
logger.info("loading configuration file {} from cache at {}".format(
|
|
|
|
|
config_file, resolved_config_file))
|
|
|
|
|
# Load config
|
|
|
|
|
config = TransfoXLConfig.from_json_file(resolved_config_file)
|
|
|
|
|
logger.info("Model config {}".format(config))
|
|
|
|
|
# Instantiate model.
|
|
|
|
|
model = cls(config, *inputs, **kwargs)
|
|
|
|
|
if state_dict is None:
|
|
|
|
|
state_dict = torch.load(resolved_archive_file)
|
|
|
|
|
|
|
|
|
|
old_keys = []
|
|
|
|
|
new_keys = []
|
|
|
|
|
for key in state_dict.keys():
|
|
|
|
|
new_key = None
|
|
|
|
|
if 'gamma' in key:
|
|
|
|
|
new_key = key.replace('gamma', 'weight')
|
|
|
|
|
if 'beta' in key:
|
|
|
|
|
new_key = key.replace('beta', 'bias')
|
|
|
|
|
if new_key:
|
|
|
|
|
old_keys.append(key)
|
|
|
|
|
new_keys.append(new_key)
|
|
|
|
|
for old_key, new_key in zip(old_keys, new_keys):
|
|
|
|
|
state_dict[new_key] = state_dict.pop(old_key)
|
|
|
|
|
|
|
|
|
|
missing_keys = []
|
|
|
|
|
unexpected_keys = []
|
|
|
|
|
error_msgs = []
|
|
|
|
|
# copy state_dict so _load_from_state_dict can modify it
|
|
|
|
|
metadata = getattr(state_dict, '_metadata', None)
|
|
|
|
|
state_dict = state_dict.copy()
|
|
|
|
|
if metadata is not None:
|
|
|
|
|
state_dict._metadata = metadata
|
|
|
|
|
|
|
|
|
|
def load(module, prefix=''):
|
|
|
|
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
|
|
|
|
module._load_from_state_dict(
|
|
|
|
|
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
|
|
|
|
for name, child in module._modules.items():
|
|
|
|
|
if child is not None:
|
|
|
|
|
load(child, prefix + name + '.')
|
|
|
|
|
# load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
|
|
|
|
if len(missing_keys) > 0:
|
|
|
|
|
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
|
|
|
|
model.__class__.__name__, missing_keys))
|
|
|
|
|
if len(unexpected_keys) > 0:
|
|
|
|
|
logger.info("Weights from pretrained model not used in {}: {}".format(
|
|
|
|
|
model.__class__.__name__, unexpected_keys))
|
|
|
|
|
if len(error_msgs) > 0:
|
|
|
|
|
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
|
|
|
|
model.__class__.__name__, "\n\t".join(error_msgs)))
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransfoXLModel(TransfoXLPreTrainedModel):
|
|
|
|
|
def __init__(self, config):
|
|
|
|
|
# n_token, n_layer, n_head, d_model, d_head, d_inner,
|
|
|
|
|
# dropout, dropatt, tie_weight=True, d_embed=None,
|
|
|
|
|
# div_val=1, tie_projs=[False], pre_lnorm=False,
|
|
|
|
|
# tgt_len=None, ext_len=None, mem_len=None,
|
|
|
|
|
# cutoffs=[], adapt_inp=False, untie_r=False,
|
|
|
|
|
# same_length=False, attn_type=0, clamp_len=-1,
|
|
|
|
|
# sample_softmax=-1, **kwargs):
|
|
|
|
|
super(TransfoXLModel, self).__init__(config)
|
|
|
|
|
self.n_token = config.n_token
|
|
|
|
|
|
|
|
|
|
self.d_embed = config.d_embed
|
|
|
|
|
self.d_model = config.d_model
|
|
|
|
|
self.n_head = config.n_head
|
|
|
|
|
self.d_head = config.d_head
|
|
|
|
|
|
|
|
|
|
self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
|
|
|
|
|
div_val=config.div_val)
|
|
|
|
|
|
|
|
|
|
self.drop = nn.Dropout(config.dropout)
|
|
|
|
|
|
|
|
|
|
self.n_layer = config.n_layer
|
|
|
|
|
|
|
|
|
|
self.tgt_len = config.tgt_len
|
|
|
|
|
self.mem_len = config.mem_len
|
|
|
|
|
self.ext_len = config.ext_len
|
|
|
|
|
self.max_klen = config.tgt_len + config.ext_len + config.mem_len
|
|
|
|
|
|
|
|
|
|
self.attn_type = config.attn_type
|
|
|
|
|
|
|
|
|
|
if not config.untie_r:
|
|
|
|
|
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
|
|
|
|
|
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
|
|
|
|
|
|
|
|
|
|
self.layers = nn.ModuleList()
|
|
|
|
|
if attn_type == 0: # the default attention
|
|
|
|
|
for i in range(n_layer):
|
|
|
|
|
if config.attn_type == 0: # the default attention
|
|
|
|
|
for i in range(config.n_layer):
|
|
|
|
|
self.layers.append(
|
|
|
|
|
RelPartialLearnableDecoderLayer(
|
|
|
|
|
n_head, d_model, d_head, d_inner, dropout,
|
|
|
|
|
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
|
|
|
|
|
dropatt=dropatt, pre_lnorm=pre_lnorm,
|
|
|
|
|
r_w_bias=None if untie_r else self.r_w_bias,
|
|
|
|
|
r_r_bias=None if untie_r else self.r_r_bias)
|
|
|
|
|
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
|
|
|
|
|
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
|
|
|
|
|
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
|
|
|
|
|
r_w_bias=None if config.untie_r else self.r_w_bias,
|
|
|
|
|
r_r_bias=None if config.untie_r else self.r_r_bias)
|
|
|
|
|
)
|
|
|
|
|
elif attn_type == 1: # learnable embeddings
|
|
|
|
|
for i in range(n_layer):
|
|
|
|
|
elif config.attn_type == 1: # learnable embeddings
|
|
|
|
|
for i in range(config.n_layer):
|
|
|
|
|
self.layers.append(
|
|
|
|
|
RelLearnableDecoderLayer(
|
|
|
|
|
n_head, d_model, d_head, d_inner, dropout,
|
|
|
|
|
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
|
|
|
|
|
dropatt=dropatt, pre_lnorm=pre_lnorm,
|
|
|
|
|
r_w_bias=None if untie_r else self.r_w_bias,
|
|
|
|
|
r_r_bias=None if untie_r else self.r_r_bias)
|
|
|
|
|
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
|
|
|
|
|
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
|
|
|
|
|
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
|
|
|
|
|
r_w_bias=None if config.untie_r else self.r_w_bias,
|
|
|
|
|
r_r_bias=None if config.untie_r else self.r_r_bias)
|
|
|
|
|
)
|
|
|
|
|
elif attn_type in [2, 3]: # absolute embeddings
|
|
|
|
|
for i in range(n_layer):
|
|
|
|
|
elif config.attn_type in [2, 3]: # absolute embeddings
|
|
|
|
|
for i in range(config.n_layer):
|
|
|
|
|
self.layers.append(
|
|
|
|
|
DecoderLayer(
|
|
|
|
|
n_head, d_model, d_head, d_inner, dropout,
|
|
|
|
|
dropatt=dropatt, pre_lnorm=pre_lnorm,
|
|
|
|
|
r_w_bias=None if untie_r else self.r_w_bias,
|
|
|
|
|
r_r_bias=None if untie_r else self.r_r_bias)
|
|
|
|
|
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
|
|
|
|
|
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
|
|
|
|
|
r_w_bias=None if config.untie_r else self.r_w_bias,
|
|
|
|
|
r_r_bias=None if config.untie_r else self.r_r_bias)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.sample_softmax = sample_softmax
|
|
|
|
|
self.sample_softmax = config.sample_softmax
|
|
|
|
|
# use sampled softmax
|
|
|
|
|
if sample_softmax > 0:
|
|
|
|
|
self.out_layer = nn.Linear(d_model, n_token)
|
|
|
|
|
if tie_weight:
|
|
|
|
|
if config.sample_softmax > 0:
|
|
|
|
|
self.out_layer = nn.Linear(config.d_model, config.n_token)
|
|
|
|
|
if config.tie_weight:
|
|
|
|
|
self.out_layer.weight = self.word_emb.weight
|
|
|
|
|
self.tie_weight = tie_weight
|
|
|
|
|
self.sampler = LogUniformSampler(n_token, sample_softmax)
|
|
|
|
|
self.tie_weight = config.tie_weight
|
|
|
|
|
self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
|
|
|
|
|
|
|
|
|
|
# use adaptive softmax (including standard softmax)
|
|
|
|
|
else:
|
|
|
|
|
self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
|
|
|
|
|
cutoffs, div_val=div_val)
|
|
|
|
|
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
|
|
|
|
|
config.cutoffs, div_val=config.div_val)
|
|
|
|
|
|
|
|
|
|
if tie_weight:
|
|
|
|
|
if config.tie_weight:
|
|
|
|
|
for i in range(len(self.crit.out_layers)):
|
|
|
|
|
self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
|
|
|
|
|
|
|
|
|
|
if tie_projs:
|
|
|
|
|
for i, tie_proj in enumerate(tie_projs):
|
|
|
|
|
if tie_proj and div_val == 1 and d_model != d_embed:
|
|
|
|
|
if config.tie_projs:
|
|
|
|
|
for i, tie_proj in enumerate(config.tie_projs):
|
|
|
|
|
if tie_proj and config.div_val == 1 and config.d_model != config.d_embed:
|
|
|
|
|
self.crit.out_projs[i] = self.word_emb.emb_projs[0]
|
|
|
|
|
elif tie_proj and div_val != 1:
|
|
|
|
|
elif tie_proj and config.div_val != 1:
|
|
|
|
|
self.crit.out_projs[i] = self.word_emb.emb_projs[i]
|
|
|
|
|
|
|
|
|
|
self.same_length = same_length
|
|
|
|
|
self.clamp_len = clamp_len
|
|
|
|
|
self.same_length = config.same_length
|
|
|
|
|
self.clamp_len = config.clamp_len
|
|
|
|
|
|
|
|
|
|
if self.attn_type == 0: # default attention
|
|
|
|
|
self.pos_emb = PositionalEmbedding(self.d_model)
|
|
|
|
|
@@ -859,8 +1030,7 @@ class MemTransformerLM(nn.Module):
|
|
|
|
|
hids.append(core_out)
|
|
|
|
|
for i, layer in enumerate(self.layers):
|
|
|
|
|
mems_i = None if mems is None else mems[i]
|
|
|
|
|
core_out = layer(core_out, pos_emb, self.r_w_bias,
|
|
|
|
|
self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
|
|
|
|
|
core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i)
|
|
|
|
|
hids.append(core_out)
|
|
|
|
|
elif self.attn_type == 1: # learnable
|
|
|
|
|
core_out = self.drop(word_emb)
|
|
|
|
|
@@ -949,220 +1119,3 @@ class MemTransformerLM(nn.Module):
|
|
|
|
|
else:
|
|
|
|
|
return [loss] + new_mems
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransfoXLPreTrainedModel(nn.Module):
|
|
|
|
|
""" An abstract class to handle weights initialization and
|
|
|
|
|
a simple interface for dowloading and loading pretrained models.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, config, *inputs, **kwargs):
|
|
|
|
|
super(TransfoXLPreTrainedModel, self).__init__()
|
|
|
|
|
if not isinstance(config, TransfoXLConfig):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
|
|
|
|
|
"To create a model from a pretrained model use "
|
|
|
|
|
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
|
|
|
|
self.__class__.__name__, self.__class__.__name__
|
|
|
|
|
))
|
|
|
|
|
self.config = config
|
|
|
|
|
|
|
|
|
|
def init_weight(self, weight):
|
|
|
|
|
if self.config.init == 'uniform':
|
|
|
|
|
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
|
|
|
|
|
elif self.config.init == 'normal':
|
|
|
|
|
nn.init.normal_(weight, 0.0, self.config.init_std)
|
|
|
|
|
|
|
|
|
|
def init_bias(self, bias):
|
|
|
|
|
nn.init.constant_(bias, 0.0)
|
|
|
|
|
|
|
|
|
|
def init_weights(self, m):
|
|
|
|
|
""" Initialize the weights.
|
|
|
|
|
"""
|
|
|
|
|
classname = m.__class__.__name__
|
|
|
|
|
if classname.find('Linear') != -1:
|
|
|
|
|
if hasattr(m, 'weight') and m.weight is not None:
|
|
|
|
|
self.init_weight(m.weight)
|
|
|
|
|
if hasattr(m, 'bias') and m.bias is not None:
|
|
|
|
|
self.init_bias(m.bias)
|
|
|
|
|
elif classname.find('AdaptiveEmbedding') != -1:
|
|
|
|
|
if hasattr(m, 'emb_projs'):
|
|
|
|
|
for i in range(len(m.emb_projs)):
|
|
|
|
|
if m.emb_projs[i] is not None:
|
|
|
|
|
nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
|
|
|
|
|
elif classname.find('Embedding') != -1:
|
|
|
|
|
if hasattr(m, 'weight'):
|
|
|
|
|
self.init_weight(m.weight)
|
|
|
|
|
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
|
|
|
|
|
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
|
|
|
|
|
self.init_weight(m.cluster_weight)
|
|
|
|
|
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
|
|
|
|
|
self.init_bias(m.cluster_bias)
|
|
|
|
|
if hasattr(m, 'out_projs'):
|
|
|
|
|
for i in range(len(m.out_projs)):
|
|
|
|
|
if m.out_projs[i] is not None:
|
|
|
|
|
nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
|
|
|
|
|
elif classname.find('LayerNorm') != -1:
|
|
|
|
|
if hasattr(m, 'weight'):
|
|
|
|
|
nn.init.normal_(m.weight, 1.0, self.config.init_std)
|
|
|
|
|
if hasattr(m, 'bias') and m.bias is not None:
|
|
|
|
|
self.init_bias(m.bias)
|
|
|
|
|
elif classname.find('TransformerLM') != -1:
|
|
|
|
|
if hasattr(m, 'r_emb'):
|
|
|
|
|
self.init_weight(m.r_emb)
|
|
|
|
|
if hasattr(m, 'r_w_bias'):
|
|
|
|
|
self.init_weight(m.r_w_bias)
|
|
|
|
|
if hasattr(m, 'r_r_bias'):
|
|
|
|
|
self.init_weight(m.r_r_bias)
|
|
|
|
|
if hasattr(m, 'r_bias'):
|
|
|
|
|
self.init_bias(m.r_bias)
|
|
|
|
|
|
|
|
|
|
def set_num_special_tokens(self, num_special_tokens):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
|
|
|
|
|
*inputs, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
|
|
|
|
Download and cache the pre-trained model file if needed.
|
|
|
|
|
|
|
|
|
|
Params:
|
|
|
|
|
pretrained_model_name: either:
|
|
|
|
|
- a str with the name of a pre-trained model to load selected in the list of:
|
|
|
|
|
. `transfo-xl`
|
|
|
|
|
- a path or url to a pretrained model archive containing:
|
|
|
|
|
. `transfo_xl_config.json` a configuration file for the model
|
|
|
|
|
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
|
|
|
|
|
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
|
|
|
|
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
|
|
|
|
*inputs, **kwargs: additional input for the specific Bert class
|
|
|
|
|
(ex: num_labels for BertForSequenceClassification)
|
|
|
|
|
"""
|
|
|
|
|
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
|
|
|
|
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
|
|
|
|
else:
|
|
|
|
|
archive_file = pretrained_model_name
|
|
|
|
|
# redirect to the cache, if necessary
|
|
|
|
|
try:
|
|
|
|
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
|
|
|
|
except FileNotFoundError:
|
|
|
|
|
logger.error(
|
|
|
|
|
"Model name '{}' was not found in model name list ({}). "
|
|
|
|
|
"We assumed '{}' was a path or url but couldn't find any file "
|
|
|
|
|
"associated to this path or url.".format(
|
|
|
|
|
pretrained_model_name,
|
|
|
|
|
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
|
|
|
|
archive_file))
|
|
|
|
|
return None
|
|
|
|
|
if resolved_archive_file == archive_file:
|
|
|
|
|
logger.info("loading archive file {}".format(archive_file))
|
|
|
|
|
else:
|
|
|
|
|
logger.info("loading archive file {} from cache at {}".format(
|
|
|
|
|
archive_file, resolved_archive_file))
|
|
|
|
|
tempdir = None
|
|
|
|
|
if os.path.isdir(resolved_archive_file):
|
|
|
|
|
serialization_dir = resolved_archive_file
|
|
|
|
|
else:
|
|
|
|
|
# Extract archive to temp dir
|
|
|
|
|
tempdir = tempfile.mkdtemp()
|
|
|
|
|
logger.info("extracting archive file {} to temp dir {}".format(
|
|
|
|
|
resolved_archive_file, tempdir))
|
|
|
|
|
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
|
|
|
|
archive.extractall(tempdir)
|
|
|
|
|
serialization_dir = tempdir
|
|
|
|
|
# Load config
|
|
|
|
|
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
|
|
|
|
config = TransfoXLConfig.from_json_file(config_file)
|
|
|
|
|
logger.info("Model config {}".format(config))
|
|
|
|
|
# Instantiate model.
|
|
|
|
|
model = cls(config, *inputs, **kwargs)
|
|
|
|
|
if state_dict is None:
|
|
|
|
|
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
|
|
|
|
state_dict = torch.load(weights_path)
|
|
|
|
|
|
|
|
|
|
old_keys = []
|
|
|
|
|
new_keys = []
|
|
|
|
|
for key in state_dict.keys():
|
|
|
|
|
new_key = None
|
|
|
|
|
if 'gamma' in key:
|
|
|
|
|
new_key = key.replace('gamma', 'weight')
|
|
|
|
|
if 'beta' in key:
|
|
|
|
|
new_key = key.replace('beta', 'bias')
|
|
|
|
|
if new_key:
|
|
|
|
|
old_keys.append(key)
|
|
|
|
|
new_keys.append(new_key)
|
|
|
|
|
for old_key, new_key in zip(old_keys, new_keys):
|
|
|
|
|
state_dict[new_key] = state_dict.pop(old_key)
|
|
|
|
|
|
|
|
|
|
missing_keys = []
|
|
|
|
|
unexpected_keys = []
|
|
|
|
|
error_msgs = []
|
|
|
|
|
# copy state_dict so _load_from_state_dict can modify it
|
|
|
|
|
metadata = getattr(state_dict, '_metadata', None)
|
|
|
|
|
state_dict = state_dict.copy()
|
|
|
|
|
if metadata is not None:
|
|
|
|
|
state_dict._metadata = metadata
|
|
|
|
|
|
|
|
|
|
def load(module, prefix=''):
|
|
|
|
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
|
|
|
|
module._load_from_state_dict(
|
|
|
|
|
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
|
|
|
|
for name, child in module._modules.items():
|
|
|
|
|
if child is not None:
|
|
|
|
|
load(child, prefix + name + '.')
|
|
|
|
|
# load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
|
|
|
|
if len(missing_keys) > 0:
|
|
|
|
|
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
|
|
|
|
model.__class__.__name__, missing_keys))
|
|
|
|
|
if len(unexpected_keys) > 0:
|
|
|
|
|
logger.info("Weights from pretrained model not used in {}: {}".format(
|
|
|
|
|
model.__class__.__name__, unexpected_keys))
|
|
|
|
|
if len(error_msgs) > 0:
|
|
|
|
|
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
|
|
|
|
model.__class__.__name__, "\n\t".join(error_msgs)))
|
|
|
|
|
if tempdir:
|
|
|
|
|
# Clean up temp dir
|
|
|
|
|
shutil.rmtree(tempdir)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransfoXLModel(TransfoXLPreTrainedModel):
|
|
|
|
|
""" Transformer XL model
|
|
|
|
|
From "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
|
|
|
|
by Zihang Dai*, Zhilin Yang*, Yiming Yang, William W. Cohen, Jaime Carbonell,
|
|
|
|
|
Quoc V. Le, Ruslan Salakhutdinov (*: equal contribution)
|
|
|
|
|
|
|
|
|
|
Params:
|
|
|
|
|
config: a TransfoXLConfig class instance with the configuration to build a new model
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
|
|
|
|
|
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
|
|
|
|
|
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
|
|
|
|
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
|
|
|
|
|
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
|
|
|
|
|
You can use it to add a third embedding (the previous two being the word and position embeddings)
|
|
|
|
|
to each token in the sentence.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
`hidden_states`: the encoded-hidden-states at the top of the model
|
|
|
|
|
as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
|
|
|
|
|
(or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
|
|
|
|
|
|
|
|
|
|
Example usage:
|
|
|
|
|
```python
|
|
|
|
|
# Already been converted into BPE token ids
|
|
|
|
|
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
|
|
|
|
|
|
|
|
|
config = modeling_transfo_xl.TransfoXLConfig()
|
|
|
|
|
|
|
|
|
|
model = modeling_transfo_xl.TransfoXLModel(config)
|
|
|
|
|
hidden_states = model(input_ids)
|
|
|
|
|
```
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, config):
|
|
|
|
|
super(TransfoXLModel, self).__init__(config)
|
|
|
|
|
self.transformer = MemTransformerLM(**config.to_dict())
|
|
|
|
|
self.apply(self.init_weights)
|
|
|
|
|
|
|
|
|
|
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
|
|
|
|
return self.transformer(input_ids, position_ids, token_type_ids)
|
|
|
|
|
|