updating tests
This commit is contained in:
@@ -617,6 +617,7 @@ class BertModel(BertPreTrainedModel):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
@@ -758,11 +759,8 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
input_embeddings = self.bert.embeddings.word_embeddings.weight
|
||||
if self.config.torchscript:
|
||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
||||
self._tie_or_clone_weights(self.cls.predictions.decoder,
|
||||
self.bert.embeddings.word_embeddings)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
||||
next_sentence_label=None, head_mask=None):
|
||||
@@ -864,11 +862,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
input_embeddings = self.bert.embeddings.word_embeddings.weight
|
||||
if self.config.torchscript:
|
||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
||||
self._tie_or_clone_weights(self.cls.predictions.decoder,
|
||||
self.bert.embeddings.word_embeddings)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user