Add common properties input_embeddings and output_embeddings
This commit is contained in:
@@ -334,9 +334,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
|
||||
load_tf_weights = None
|
||||
base_model_prefix = "distilbert"
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights.
|
||||
"""
|
||||
@@ -424,12 +421,14 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
@property
|
||||
def input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
@input_embeddings.setter
|
||||
def input_embeddings(self, new_embeddings):
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
@@ -511,16 +510,12 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
||||
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
|
||||
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
|
||||
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
self._tie_or_clone_weights(self.vocab_projector,
|
||||
self.distilbert.embeddings.word_embeddings)
|
||||
@property
|
||||
def output_embeddings(self):
|
||||
return self.vocab_projector
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
|
||||
dlbrt_output = self.distilbert(input_ids=input_ids,
|
||||
|
||||
Reference in New Issue
Block a user