Add common properties input_embeddings and output_embeddings

This commit is contained in:
thomwolf
2019-11-04 12:28:56 +01:00
parent 8a62835577
commit 9b45d0f878
12 changed files with 179 additions and 153 deletions

View File

@@ -280,12 +280,14 @@ class XxxModel(XxxPreTrainedModel):
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
@property
def input_embeddings(self):
return self.embeddings.word_embeddings
@input_embeddings.setter
def input_embeddings(self, new_embeddings):
self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -376,17 +378,13 @@ class XxxForMaskedLM(XxxPreTrainedModel):
super(XxxForMaskedLM, self).__init__(config)
self.transformer = XxxModel(config)
self.cls = XxxOnlyMLMHead(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.transformer.embeddings.word_embeddings)
@property
def output_embeddings(self):
return self.lm_head
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None):