embeddings resizing + tie_weights
This commit is contained in:
@@ -104,7 +104,6 @@ class XLMConfig(PretrainedConfig):
|
||||
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file=30145,
|
||||
n_special=0,
|
||||
emb_dim=2048,
|
||||
n_layers=12,
|
||||
n_heads=16,
|
||||
@@ -148,7 +147,6 @@ class XLMConfig(PretrainedConfig):
|
||||
self.__dict__[key] = value
|
||||
elif isinstance(vocab_size_or_config_json_file, int):
|
||||
self.n_words = vocab_size_or_config_json_file
|
||||
self.n_special = n_special
|
||||
self.emb_dim = emb_dim
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
@@ -183,8 +181,8 @@ class XLMConfig(PretrainedConfig):
|
||||
"or the path to a pretrained model config file (str)")
|
||||
|
||||
@property
def vocab_size(self):
    """Size of the token vocabulary.

    Returns ``self.n_words``, the value ``__init__`` stores from
    ``vocab_size_or_config_json_file`` when that argument is an int.
    """
    return self.n_words
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
@@ -479,6 +477,9 @@ class XLMModel(XLMPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
    """Replace ``self.embeddings`` with a version resized to ``new_num_tokens``.

    The resizing itself is delegated to ``self._get_resized_embeddings``
    (defined outside this view); whatever it returns is stored back on
    the model as the new ``embeddings`` attribute.
    """
    self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
@@ -718,8 +719,6 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLMWithLMHeadModel, self).__init__(config)
|
||||
self.torchscript = config.torchscript
|
||||
|
||||
self.transformer = XLMModel(config)
|
||||
self.pred_layer = XLMPredLayer(config)
|
||||
|
||||
@@ -729,7 +728,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
def tie_weights(self):
    """Make sure the prediction layer and the input embeddings share weights.

    When ``self.config.torchscript`` is true, the projection receives its
    own ``nn.Parameter`` holding a clone of the embedding matrix; otherwise
    it is bound to the very same parameter object as the embeddings.
    """
    # The flag is read from the config: the model no longer keeps its own
    # ``torchscript`` copy, so there is a single source of truth.
    if self.config.torchscript:
        # NOTE(review): presumably cloned because TorchScript export cannot
        # handle parameters shared between modules — confirm.
        self.pred_layer.proj.weight = nn.Parameter(self.transformer.embeddings.weight.clone())
    else:
        self.pred_layer.proj.weight = self.transformer.embeddings.weight
|
||||
|
||||
Reference in New Issue
Block a user