Add common properties input_embeddings and output_embeddings
This commit is contained in:
@@ -360,10 +360,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
    """ Replace ``self.tokens_embed`` with a resized version of itself.

    The resizing itself is delegated to ``self._get_resized_embeddings``, which
    receives the current ``tokens_embed`` and ``new_num_tokens``; whatever it
    returns is stored back on ``self.tokens_embed``.
    """
    resized = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
    self.tokens_embed = resized
|
||||
@property
def input_embeddings(self):
    """ Input embeddings of the model, i.e. whatever is stored in ``self.tokens_embed``. """
    embeddings = self.tokens_embed
    return embeddings

@input_embeddings.setter
def input_embeddings(self, new_embeddings):
    # Assigning to the property rebinds ``self.tokens_embed`` to the given object.
    self.tokens_embed = new_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
@@ -489,14 +493,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
    """ Share the input embeddings with the output embeddings.

        TorchScript export can't handle parameter sharing, so in that case the
        weights are cloned instead.
    """
    tie_or_clone = self._tie_or_clone_weights
    output_layer = self.lm_head
    input_layer = self.transformer.tokens_embed
    tie_or_clone(output_layer, input_layer)
|
||||
@property
def output_embeddings(self):
    """ Output embeddings of the model, i.e. whatever is stored in ``self.lm_head``. """
    head = self.lm_head
    return head
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
labels=None):
|
||||
@@ -583,14 +583,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
self.multiple_choice_head = SequenceSummary(config)
|
||||
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
    """ Share the input embeddings with the output embeddings.

        TorchScript export can't handle parameter sharing, so in that case the
        weights are cloned instead.
    """
    tie_or_clone = self._tie_or_clone_weights
    tie_or_clone(
        self.lm_head,
        self.transformer.tokens_embed,
    )
|
||||
@property
def output_embeddings(self):
    """ Output embeddings of the model, i.e. whatever is stored in ``self.lm_head``. """
    lm_head = self.lm_head
    return lm_head
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
mc_token_ids=None, lm_labels=None, mc_labels=None):
|
||||
|
||||
Reference in New Issue
Block a user