Add common properties input_embeddings and output_embeddings

This commit is contained in:
thomwolf
2019-11-04 12:28:56 +01:00
parent 8a62835577
commit 9b45d0f878
12 changed files with 179 additions and 153 deletions

View File

@@ -334,9 +334,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
load_tf_weights = None
base_model_prefix = "distilbert"
def __init__(self, *inputs, **kwargs):
super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
def _init_weights(self, module):
""" Initialize the weights.
"""
@@ -424,12 +421,14 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
@property
def input_embeddings(self):
return self.embeddings.word_embeddings
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -511,16 +510,12 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
self.init_weights()
self.tie_weights()
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.vocab_projector,
self.distilbert.embeddings.word_embeddings)
@property
def output_embeddings(self):
return self.vocab_projector
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
dlbrt_output = self.distilbert(input_ids=input_ids,