updating tests

This commit is contained in:
thomwolf
2019-07-12 10:57:58 +02:00
parent 3fbceed8d2
commit 2918b7d2a0
14 changed files with 672 additions and 596 deletions

View File

@@ -617,6 +617,7 @@ class BertModel(BertPreTrainedModel):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
return self.embeddings.word_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
@@ -758,11 +759,8 @@ class BertForPreTraining(BertPreTrainedModel):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
input_embeddings = self.bert.embeddings.word_embeddings.weight
if self.config.torchscript:
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
else:
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
next_sentence_label=None, head_mask=None):
@@ -864,11 +862,8 @@ class BertForMaskedLM(BertPreTrainedModel):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
input_embeddings = self.bert.embeddings.word_embeddings.weight
if self.config.torchscript:
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
else:
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
"""