embeddings resizing + tie_weights

This commit is contained in:
thomwolf
2019-07-12 00:02:49 +02:00
parent 50e62a4cb4
commit bd404735a7
15 changed files with 196 additions and 332 deletions

View File

@@ -104,7 +104,6 @@ class XLMConfig(PretrainedConfig):
def __init__(self,
vocab_size_or_config_json_file=30145,
n_special=0,
emb_dim=2048,
n_layers=12,
n_heads=16,
@@ -148,7 +147,6 @@ class XLMConfig(PretrainedConfig):
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.n_words = vocab_size_or_config_json_file
self.n_special = n_special
self.emb_dim = emb_dim
self.n_layers = n_layers
self.n_heads = n_heads
@@ -183,8 +181,8 @@ class XLMConfig(PretrainedConfig):
"or the path to a pretrained model config file (str)")
# Read-only accessor returning the sum of the `n_words` and `n_special`
# config attributes.
# NOTE(review): this is a diff view whose +/- markers were stripped; this
# accessor looks like the pre-commit version that the `vocab_size` def
# right below replaces — confirm against the actual file.
@property
def total_tokens_embeddings(self):
return self.n_words + self.n_special
# Exposes the `n_words` config attribute under the name `vocab_size`.
# NOTE(review): no `@property` decorator is visible directly above this def
# in this view, unlike the neighbouring accessors. The diff markers were
# lost, so this is presumably the added line that takes over the
# `@property` of `total_tokens_embeddings` — confirm against the real file.
def vocab_size(self):
return self.n_words
@property
def hidden_size(self):
@@ -479,6 +477,9 @@ class XLMModel(XLMPreTrainedModel):
self.apply(self.init_weights)
# Replace `self.embeddings` with whatever `self._get_resized_embeddings`
# returns for the current embeddings and `new_num_tokens`.
# The helper is not defined in this view; presumably it returns an embedding
# module sized for `new_num_tokens` tokens — verify against the base class.
# Returns None; the only visible effect is rebinding `self.embeddings`.
def _resize_token_embeddings(self, new_num_tokens):
self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -718,8 +719,6 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
"""
def __init__(self, config):
super(XLMWithLMHeadModel, self).__init__(config)
self.torchscript = config.torchscript
self.transformer = XLMModel(config)
self.pred_layer = XLMPredLayer(config)
@@ -729,7 +728,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def tie_weights(self):
""" Make the output projection (`pred_layer.proj`) use the input embedding weights.

When the torchscript flag is set, the projection gets its own `nn.Parameter`
built from a clone of the embedding weights; otherwise both modules reference
the same weight object.
"""
# NOTE(review): the two `if` lines below appear to be the removed and added
# versions of the same condition (the diff's +/- markers were stripped from
# this view). Presumably only the `self.config.torchscript` form remains
# after the commit — confirm against the actual file.
if self.torchscript:
if self.config.torchscript:
# Separate copy of the values rather than a shared object.
# Presumably needed because TorchScript export cannot handle shared
# parameters — TODO confirm.
self.pred_layer.proj.weight = nn.Parameter(self.transformer.embeddings.weight.clone())
else:
# Same object on both sides, so the two modules share one weight.
self.pred_layer.proj.weight = self.transformer.embeddings.weight