switch from properties to methods
This commit is contained in:
@@ -281,11 +281,10 @@ class XxxModel(XxxPreTrainedModel):
|
|||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.embeddings.word_embeddings
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
@input_embeddings.setter
|
def set_input_embeddings(self, new_embeddings):
|
||||||
def input_embeddings(self, new_embeddings):
|
|
||||||
self.embeddings.word_embeddings = new_embeddings
|
self.embeddings.word_embeddings = new_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
@@ -382,8 +381,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
|
|||||||
@@ -601,13 +601,11 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
|
||||||
return self.embeddings.word_embeddings
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
@input_embeddings.setter
|
def set_input_embeddings(self, value):
|
||||||
def input_embeddings(self, new_embeddings):
|
self.embeddings.word_embeddings = value
|
||||||
self.embeddings.word_embeddings = new_embeddings
|
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
""" Prunes heads of the model.
|
""" Prunes heads of the model.
|
||||||
@@ -753,8 +751,7 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
@@ -829,8 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
|
|||||||
@@ -289,12 +289,10 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
|
||||||
return self.w
|
return self.w
|
||||||
|
|
||||||
@input_embeddings.setter
|
def set_input_embeddings(self, new_embeddings):
|
||||||
def input_embeddings(self, new_embeddings):
|
|
||||||
self.w = new_embeddings
|
self.w = new_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
@@ -454,8 +452,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
|
|||||||
@@ -421,12 +421,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
|
||||||
return self.embeddings.word_embeddings
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
@input_embeddings.setter
|
def set_input_embeddings(self, new_embeddings):
|
||||||
def input_embeddings(self, new_embeddings):
|
|
||||||
self.embeddings.word_embeddings = new_embeddings
|
self.embeddings.word_embeddings = new_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
@@ -513,8 +511,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
|||||||
|
|
||||||
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
|
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.vocab_projector
|
return self.vocab_projector
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
|
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
|
||||||
|
|||||||
@@ -357,12 +357,10 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
|
||||||
return self.wte
|
return self.wte
|
||||||
|
|
||||||
@input_embeddings.setter
|
def set_input_embeddings(self, new_embeddings):
|
||||||
def input_embeddings(self, new_embeddings):
|
|
||||||
self.wte = new_embeddings
|
self.wte = new_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
@@ -519,8 +517,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
@@ -623,8 +620,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
|
|||||||
@@ -360,12 +360,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
|
||||||
return self.tokens_embed
|
return self.tokens_embed
|
||||||
|
|
||||||
@input_embeddings.setter
|
def set_input_embeddings(self, new_embeddings):
|
||||||
def input_embeddings(self, new_embeddings):
|
|
||||||
self.tokens_embed = new_embeddings
|
self.tokens_embed = new_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
@@ -494,8 +492,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
@@ -584,8 +581,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
|
|||||||
@@ -169,10 +169,11 @@ class RobertaModel(BertModel):
|
|||||||
self.embeddings = RobertaEmbeddings(config)
|
self.embeddings = RobertaEmbeddings(config)
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
|
||||||
return self.embeddings.word_embeddings
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_emebddings = value
|
||||||
|
|
||||||
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
|
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
|
||||||
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
|
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
|
||||||
@@ -218,8 +219,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.lm_head.decoder
|
return self.lm_head.decoder
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
|
|||||||
@@ -639,12 +639,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
|
||||||
return self.word_emb
|
return self.word_emb
|
||||||
|
|
||||||
@input_embeddings.setter
|
def set_input_embeddings(self, new_embeddings):
|
||||||
def input_embeddings(self, new_embeddings):
|
|
||||||
self.word_emb = new_embeddings
|
self.word_emb = new_embeddings
|
||||||
|
|
||||||
def backward_compatible(self):
|
def backward_compatible(self):
|
||||||
|
|||||||
@@ -87,21 +87,37 @@ class PreTrainedModel(nn.Module):
|
|||||||
def base_model(self):
|
def base_model(self):
|
||||||
return getattr(self, self.base_model_prefix, self)
|
return getattr(self, self.base_model_prefix, self)
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
""" Get model's input embeddings
|
||||||
|
"""
|
||||||
base_model = getattr(self, self.base_model_prefix, self)
|
base_model = getattr(self, self.base_model_prefix, self)
|
||||||
return base_model.input_embeddings
|
if base_model is not self:
|
||||||
|
return base_model.get_input_embeddings()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
def set_input_embeddings(self, value):
|
||||||
def output_embeddings(self):
|
""" Set model's input embeddings
|
||||||
|
"""
|
||||||
|
base_model = getattr(self, self.base_model_prefix, self)
|
||||||
|
if base_model is not self:
|
||||||
|
base_model.set_input_embeddings(value)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
""" Get model's output embeddings
|
||||||
|
Return None if the model doesn't have output embeddings
|
||||||
|
"""
|
||||||
return None # Overwrite for models with output embeddings
|
return None # Overwrite for models with output embeddings
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
""" Make sure we are sharing the input and output embeddings.
|
""" Make sure we are sharing the input and output embeddings.
|
||||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||||
"""
|
"""
|
||||||
if self.output_embeddings is not None:
|
output_embeddings = self.get_output_embeddings()
|
||||||
self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings)
|
if output_embeddings is not None:
|
||||||
|
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||||
|
|
||||||
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
||||||
""" Tie or clone module weights depending of weither we are using TorchScript or not
|
""" Tie or clone module weights depending of weither we are using TorchScript or not
|
||||||
@@ -150,9 +166,10 @@ class PreTrainedModel(nn.Module):
|
|||||||
return model_embeds
|
return model_embeds
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def _resize_token_embeddings(self, new_num_tokens):
|
||||||
old_embeddings = self.input_embeddings
|
old_embeddings = self.get_input_embeddings()
|
||||||
self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||||
return self.input_embeddings
|
self.set_input_embeddings(new_embeddings)
|
||||||
|
return self.get_input_embeddings()
|
||||||
|
|
||||||
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
|
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
|
||||||
""" Build a resized Embedding Module from a provided token Embedding Module.
|
""" Build a resized Embedding Module from a provided token Embedding Module.
|
||||||
|
|||||||
@@ -407,12 +407,10 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
|
||||||
return self.embeddings
|
return self.embeddings
|
||||||
|
|
||||||
@input_embeddings.setter
|
def set_input_embeddings(self, new_embeddings):
|
||||||
def input_embeddings(self, new_embeddings):
|
|
||||||
self.embeddings = new_embeddings
|
self.embeddings = new_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
@@ -623,8 +621,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.pred_layer.proj
|
return self.pred_layer.proj
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
||||||
|
|||||||
@@ -611,12 +611,10 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_input_embeddings(self):
|
||||||
def input_embeddings(self):
|
|
||||||
return self.word_embedding
|
return self.word_embedding
|
||||||
|
|
||||||
@input_embeddings.setter
|
def set_input_embeddings(self, new_embeddings):
|
||||||
def input_embeddings(self, new_embeddings):
|
|
||||||
self.word_embedding = new_embeddings
|
self.word_embedding = new_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
@@ -923,8 +921,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
def get_output_embeddings(self):
|
||||||
def output_embeddings(self):
|
|
||||||
return self.lm_loss
|
return self.lm_loss
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
|
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
|
||||||
|
|||||||
@@ -429,6 +429,12 @@ class CommonTestCases:
|
|||||||
list(hidden_states[0].shape[-2:]),
|
list(hidden_states[0].shape[-2:]),
|
||||||
[self.model_tester.seq_length, self.model_tester.hidden_size])
|
[self.model_tester.seq_length, self.model_tester.hidden_size])
|
||||||
|
|
||||||
|
def test_debug(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model_embed = model.resize_token_embeddings(config.vocab_size + 10)
|
||||||
|
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
if not self.test_resize_embeddings:
|
if not self.test_resize_embeddings:
|
||||||
@@ -468,9 +474,9 @@ class CommonTestCases:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
self.assertTrue(hasattr(model, 'input_embeddings'))
|
model.get_input_embeddings()
|
||||||
setattr(model, 'input_embeddings', torch.nn.Embedding(10, 10))
|
model.set_input_embeddings(torch.nn.Embedding(10, 10))
|
||||||
self.assertTrue(hasattr(model, 'output_embeddings'))
|
model.get_output_embeddings()
|
||||||
|
|
||||||
def test_tie_model_weights(self):
|
def test_tie_model_weights(self):
|
||||||
if not self.test_torchscript:
|
if not self.test_torchscript:
|
||||||
|
|||||||
Reference in New Issue
Block a user