switch from properties to methods

This commit is contained in:
thomwolf
2019-11-04 15:34:10 +01:00
parent 9b45d0f878
commit 1724cee8c4
12 changed files with 70 additions and 75 deletions

View File

@@ -281,11 +281,10 @@ class XxxModel(XxxPreTrainedModel):
self.init_weights() self.init_weights()
@property @property
def input_embeddings(self): def get_input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings.word_embeddings
@input_embeddings.setter def set_input_embeddings(self, new_embeddings):
def input_embeddings(self, new_embeddings):
self.embeddings.word_embeddings = new_embeddings self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
@@ -382,8 +381,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.lm_head return self.lm_head
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -601,13 +601,11 @@ class BertModel(BertPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_input_embeddings(self):
def input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings.word_embeddings
@input_embeddings.setter def set_input_embeddings(self, value):
def input_embeddings(self, new_embeddings): self.embeddings.word_embeddings = value
self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """ Prunes heads of the model.
@@ -753,8 +751,7 @@ class BertForPreTraining(BertPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.cls.predictions.decoder return self.cls.predictions.decoder
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
@@ -829,8 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.cls.predictions.decoder return self.cls.predictions.decoder
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -289,12 +289,10 @@ class CTRLModel(CTRLPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_input_embeddings(self):
def input_embeddings(self):
return self.w return self.w
@input_embeddings.setter def set_input_embeddings(self, new_embeddings):
def input_embeddings(self, new_embeddings):
self.w = new_embeddings self.w = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
@@ -454,8 +452,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.lm_head return self.lm_head
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -421,12 +421,10 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_input_embeddings(self):
def input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings.word_embeddings
@input_embeddings.setter def set_input_embeddings(self, new_embeddings):
def input_embeddings(self, new_embeddings):
self.embeddings.word_embeddings = new_embeddings self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
@@ -513,8 +511,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.vocab_projector return self.vocab_projector
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None): def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):

View File

@@ -357,12 +357,10 @@ class GPT2Model(GPT2PreTrainedModel):
self.init_weights() self.init_weights()
@property def get_input_embeddings(self):
def input_embeddings(self):
return self.wte return self.wte
@input_embeddings.setter def set_input_embeddings(self, new_embeddings):
def input_embeddings(self, new_embeddings):
self.wte = new_embeddings self.wte = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
@@ -519,8 +517,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.lm_head return self.lm_head
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
@@ -623,8 +620,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.lm_head return self.lm_head
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -360,12 +360,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_input_embeddings(self):
def input_embeddings(self):
return self.tokens_embed return self.tokens_embed
@input_embeddings.setter def set_input_embeddings(self, new_embeddings):
def input_embeddings(self, new_embeddings):
self.tokens_embed = new_embeddings self.tokens_embed = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
@@ -494,8 +492,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.lm_head return self.lm_head
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
@@ -584,8 +581,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.lm_head return self.lm_head
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -169,10 +169,11 @@ class RobertaModel(BertModel):
self.embeddings = RobertaEmbeddings(config) self.embeddings = RobertaEmbeddings(config)
self.init_weights() self.init_weights()
@property def get_input_embeddings(self):
def input_embeddings(self):
return self.embeddings.word_embeddings return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_emebddings = value
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """,
ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING) ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
@@ -218,8 +219,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.lm_head.decoder return self.lm_head.decoder
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,

View File

@@ -639,12 +639,10 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_input_embeddings(self):
def input_embeddings(self):
return self.word_emb return self.word_emb
@input_embeddings.setter def set_input_embeddings(self, new_embeddings):
def input_embeddings(self, new_embeddings):
self.word_emb = new_embeddings self.word_emb = new_embeddings
def backward_compatible(self): def backward_compatible(self):

View File

@@ -87,21 +87,37 @@ class PreTrainedModel(nn.Module):
def base_model(self): def base_model(self):
return getattr(self, self.base_model_prefix, self) return getattr(self, self.base_model_prefix, self)
@property def get_input_embeddings(self):
def input_embeddings(self): """ Get model's input embeddings
"""
base_model = getattr(self, self.base_model_prefix, self) base_model = getattr(self, self.base_model_prefix, self)
return base_model.input_embeddings if base_model is not self:
return base_model.get_input_embeddings()
else:
raise NotImplementedError
@property def set_input_embeddings(self, value):
def output_embeddings(self): """ Set model's input embeddings
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
base_model.set_input_embeddings(value)
else:
raise NotImplementedError
def get_output_embeddings(self):
""" Get model's output embeddings
Return None if the model doesn't have output embeddings
"""
return None # Overwrite for models with output embeddings return None # Overwrite for models with output embeddings
def tie_weights(self): def tie_weights(self):
""" Make sure we are sharing the input and output embeddings. """ Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead. Export to TorchScript can't handle parameter sharing so we are cloning them instead.
""" """
if self.output_embeddings is not None: output_embeddings = self.get_output_embeddings()
self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings) if output_embeddings is not None:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
def _tie_or_clone_weights(self, output_embeddings, input_embeddings): def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
""" Tie or clone module weights depending of weither we are using TorchScript or not """ Tie or clone module weights depending of weither we are using TorchScript or not
@@ -150,9 +166,10 @@ class PreTrainedModel(nn.Module):
return model_embeds return model_embeds
def _resize_token_embeddings(self, new_num_tokens): def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.input_embeddings old_embeddings = self.get_input_embeddings()
self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
return self.input_embeddings self.set_input_embeddings(new_embeddings)
return self.get_input_embeddings()
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
""" Build a resized Embedding Module from a provided token Embedding Module. """ Build a resized Embedding Module from a provided token Embedding Module.

View File

@@ -407,12 +407,10 @@ class XLMModel(XLMPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_input_embeddings(self):
def input_embeddings(self):
return self.embeddings return self.embeddings
@input_embeddings.setter def set_input_embeddings(self, new_embeddings):
def input_embeddings(self, new_embeddings):
self.embeddings = new_embeddings self.embeddings = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
@@ -623,8 +621,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.pred_layer.proj return self.pred_layer.proj
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None, def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,

View File

@@ -611,12 +611,10 @@ class XLNetModel(XLNetPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_input_embeddings(self):
def input_embeddings(self):
return self.word_embedding return self.word_embedding
@input_embeddings.setter def set_input_embeddings(self, new_embeddings):
def input_embeddings(self, new_embeddings):
self.word_embedding = new_embeddings self.word_embedding = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
@@ -923,8 +921,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
self.init_weights() self.init_weights()
@property def get_output_embeddings(self):
def output_embeddings(self):
return self.lm_loss return self.lm_loss
def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,

View File

@@ -429,6 +429,12 @@ class CommonTestCases:
list(hidden_states[0].shape[-2:]), list(hidden_states[0].shape[-2:]),
[self.model_tester.seq_length, self.model_tester.hidden_size]) [self.model_tester.seq_length, self.model_tester.hidden_size])
def test_debug(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model_embed = model.resize_token_embeddings(config.vocab_size + 10)
def test_resize_tokens_embeddings(self): def test_resize_tokens_embeddings(self):
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings: if not self.test_resize_embeddings:
@@ -468,9 +474,9 @@ class CommonTestCases:
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
self.assertTrue(hasattr(model, 'input_embeddings')) model.get_input_embeddings()
setattr(model, 'input_embeddings', torch.nn.Embedding(10, 10)) model.set_input_embeddings(torch.nn.Embedding(10, 10))
self.assertTrue(hasattr(model, 'output_embeddings')) model.get_output_embeddings()
def test_tie_model_weights(self): def test_tie_model_weights(self):
if not self.test_torchscript: if not self.test_torchscript: