updating tests

This commit is contained in:
thomwolf
2019-07-12 10:57:58 +02:00
parent 3fbceed8d2
commit 2918b7d2a0
14 changed files with 672 additions and 596 deletions

View File

@@ -184,6 +184,10 @@ class XLMConfig(PretrainedConfig):
def vocab_size(self):
return self.n_words
@vocab_size.setter
def vocab_size(self, value):
self.n_words = value
@property
def hidden_size(self):
return self.emb_dim
@@ -479,6 +483,7 @@ class XLMModel(XLMPreTrainedModel):
def _resize_token_embeddings(self, new_num_tokens):
self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)
return self.embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
@@ -728,10 +733,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def tie_weights(self):
""" Make sure we are sharing the embeddings
"""
if self.config.torchscript:
self.pred_layer.proj.weight = nn.Parameter(self.transformer.embeddings.weight.clone())
else:
self.pred_layer.proj.weight = self.transformer.embeddings.weight
self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings)
def forward(self, input_ids, lengths=None, positions=None, langs=None, token_type_ids=None,
attention_mask=None, cache=None, labels=None, head_mask=None):