From 9b45d0f8787a19570a04732b0875c94951870766 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 4 Nov 2019 12:28:56 +0100 Subject: [PATCH] Add common properties input_embeddings and output_embeddings --- templates/adding_a_new_model/modeling_xxx.py | 22 ++-- transformers/modeling_bert.py | 30 ++--- transformers/modeling_ctrl.py | 17 +-- transformers/modeling_distilbert.py | 23 ++-- transformers/modeling_gpt2.py | 28 ++--- transformers/modeling_openai.py | 28 ++--- transformers/modeling_roberta.py | 13 +- transformers/modeling_transfo_xl.py | 8 +- transformers/modeling_utils.py | 122 ++++++++++++------- transformers/modeling_xlm.py | 16 +-- transformers/modeling_xlnet.py | 16 +-- transformers/tests/modeling_common_test.py | 9 ++ 12 files changed, 179 insertions(+), 153 deletions(-) diff --git a/templates/adding_a_new_model/modeling_xxx.py b/templates/adding_a_new_model/modeling_xxx.py index 7e2ba9dfb5..5335439dfb 100644 --- a/templates/adding_a_new_model/modeling_xxx.py +++ b/templates/adding_a_new_model/modeling_xxx.py @@ -280,12 +280,14 @@ class XxxModel(XxxPreTrainedModel): self.init_weights() - def _resize_token_embeddings(self, new_num_tokens): - old_embeddings = self.embeddings.word_embeddings - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) - self.embeddings.word_embeddings = new_embeddings + @property + def input_embeddings(self): return self.embeddings.word_embeddings + @input_embeddings.setter + def input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} @@ -376,17 +378,13 @@ class XxxForMaskedLM(XxxPreTrainedModel): super(XxxForMaskedLM, self).__init__(config) self.transformer = XxxModel(config) - self.cls = XxxOnlyMLMHead(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. - Export to TorchScript can't handle parameter sharing so we are cloning them instead. - """ - self._tie_or_clone_weights(self.cls.predictions.decoder, - self.transformer.embeddings.word_embeddings) + @property + def output_embeddings(self): + return self.lm_head def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, masked_lm_labels=None): diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index ff13a45bad..a920aa86d3 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -601,12 +601,14 @@ class BertModel(BertPreTrainedModel): self.init_weights() - def _resize_token_embeddings(self, new_num_tokens): - old_embeddings = self.embeddings.word_embeddings - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) - self.embeddings.word_embeddings = new_embeddings + @property + def input_embeddings(self): return self.embeddings.word_embeddings + @input_embeddings.setter + def input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} @@ -750,14 +752,10 @@ class BertForPreTraining(BertPreTrainedModel): self.cls = BertPreTrainingHeads(config) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. 
- Export to TorchScript can't handle parameter sharing so we are cloning them instead. - """ - self._tie_or_clone_weights(self.cls.predictions.decoder, - self.bert.embeddings.word_embeddings) + @property + def output_embeddings(self): + return self.cls.predictions.decoder def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, masked_lm_labels=None, next_sentence_label=None): @@ -830,14 +828,10 @@ class BertForMaskedLM(BertPreTrainedModel): self.cls = BertOnlyMLMHead(config) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. - Export to TorchScript can't handle parameter sharing so we are cloning them instead. - """ - self._tie_or_clone_weights(self.cls.predictions.decoder, - self.bert.embeddings.word_embeddings) + @property + def output_embeddings(self): + return self.cls.predictions.decoder def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ): diff --git a/transformers/modeling_ctrl.py b/transformers/modeling_ctrl.py index 55e64d318b..c588dc30ba 100644 --- a/transformers/modeling_ctrl.py +++ b/transformers/modeling_ctrl.py @@ -289,10 +289,14 @@ class CTRLModel(CTRLPreTrainedModel): self.init_weights() - def _resize_token_embeddings(self, new_num_tokens): - self.w = self._get_resized_embeddings(self.w, new_num_tokens) + @property + def input_embeddings(self): return self.w + @input_embeddings.setter + def input_embeddings(self, new_embeddings): + self.w = new_embeddings + def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} @@ -449,13 +453,10 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. - Export to TorchScript can't handle parameter sharing so we are cloning them instead. - """ - self._tie_or_clone_weights(self.lm_head, self.transformer.w) + @property + def output_embeddings(self): + return self.lm_head def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None): diff --git a/transformers/modeling_distilbert.py b/transformers/modeling_distilbert.py index d3b4ccff5d..7365d1a7dc 100644 --- a/transformers/modeling_distilbert.py +++ b/transformers/modeling_distilbert.py @@ -334,9 +334,6 @@ class DistilBertPreTrainedModel(PreTrainedModel): load_tf_weights = None base_model_prefix = "distilbert" - def __init__(self, *inputs, **kwargs): - super(DistilBertPreTrainedModel, self).__init__(*inputs, **kwargs) - def _init_weights(self, module): """ Initialize the weights. """ @@ -424,12 +421,14 @@ class DistilBertModel(DistilBertPreTrainedModel): self.init_weights() - def _resize_token_embeddings(self, new_num_tokens): - old_embeddings = self.embeddings.word_embeddings - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) - self.embeddings.word_embeddings = new_embeddings + @property + def input_embeddings(self): return self.embeddings.word_embeddings + @input_embeddings.setter + def input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} @@ -511,16 +510,12 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel): self.vocab_projector = nn.Linear(config.dim, config.vocab_size) self.init_weights() - self.tie_weights() self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. - Export to TorchScript can't handle parameter sharing so we are cloning them instead. - """ - self._tie_or_clone_weights(self.vocab_projector, - self.distilbert.embeddings.word_embeddings) + @property + def output_embeddings(self): + return self.vocab_projector def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None): dlbrt_output = self.distilbert(input_ids=input_ids, diff --git a/transformers/modeling_gpt2.py b/transformers/modeling_gpt2.py index 0b5b83aa75..9f48152884 100644 --- a/transformers/modeling_gpt2.py +++ b/transformers/modeling_gpt2.py @@ -357,10 +357,14 @@ class GPT2Model(GPT2PreTrainedModel): self.init_weights() - def _resize_token_embeddings(self, new_num_tokens): - self.wte = self._get_resized_embeddings(self.wte, new_num_tokens) + @property + def input_embeddings(self): return self.wte + @input_embeddings.setter + def input_embeddings(self, new_embeddings): + self.wte = new_embeddings + def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} @@ -514,14 +518,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. - Export to TorchScript can't handle parameter sharing so we are cloning them instead. 
- """ - self._tie_or_clone_weights(self.lm_head, - self.transformer.wte) + @property + def output_embeddings(self): + return self.lm_head def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None): @@ -622,14 +622,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): self.multiple_choice_head = SequenceSummary(config) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. - Export to TorchScript can't handle parameter sharing so we are cloning them instead. - """ - self._tie_or_clone_weights(self.lm_head, - self.transformer.wte) + @property + def output_embeddings(self): + return self.lm_head def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, lm_labels=None, mc_labels=None): diff --git a/transformers/modeling_openai.py b/transformers/modeling_openai.py index 52f3b7db72..9e25b9cfb4 100644 --- a/transformers/modeling_openai.py +++ b/transformers/modeling_openai.py @@ -360,10 +360,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): self.init_weights() - def _resize_token_embeddings(self, new_num_tokens): - self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens) + @property + def input_embeddings(self): return self.tokens_embed + @input_embeddings.setter + def input_embeddings(self, new_embeddings): + self.tokens_embed = new_embeddings + def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} @@ -489,14 +493,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. 
- Export to TorchScript can't handle parameter sharing so we are cloning them instead. - """ - self._tie_or_clone_weights(self.lm_head, - self.transformer.tokens_embed) + @property + def output_embeddings(self): + return self.lm_head def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None): @@ -583,14 +583,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): self.multiple_choice_head = SequenceSummary(config) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. - Export to TorchScript can't handle parameter sharing so we are cloning them instead. - """ - self._tie_or_clone_weights(self.lm_head, - self.transformer.tokens_embed) + @property + def output_embeddings(self): + return self.lm_head def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, mc_token_ids=None, lm_labels=None, mc_labels=None): diff --git a/transformers/modeling_roberta.py b/transformers/modeling_roberta.py index c155856be7..81216c93d4 100644 --- a/transformers/modeling_roberta.py +++ b/transformers/modeling_roberta.py @@ -169,6 +169,10 @@ class RobertaModel(BertModel): self.embeddings = RobertaEmbeddings(config) self.init_weights() + @property + def input_embeddings(self): + return self.embeddings.word_embeddings + @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING) @@ -213,13 +217,10 @@ class RobertaForMaskedLM(BertPreTrainedModel): self.lm_head = RobertaLMHead(config) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the input and output embeddings. - Export to TorchScript can't handle parameter sharing so we are cloning them instead. 
- """ - self._tie_or_clone_weights(self.lm_head.decoder, self.roberta.embeddings.word_embeddings) + @property + def output_embeddings(self): + return self.lm_head.decoder def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, masked_lm_labels=None): diff --git a/transformers/modeling_transfo_xl.py b/transformers/modeling_transfo_xl.py index 6d430e1804..0bc7fadd77 100644 --- a/transformers/modeling_transfo_xl.py +++ b/transformers/modeling_transfo_xl.py @@ -639,9 +639,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel): self.init_weights() - def _resize_token_embeddings(self, new_num_tokens): + @property + def input_embeddings(self): return self.word_emb + @input_embeddings.setter + def input_embeddings(self, new_embeddings): + self.word_emb = new_embeddings + def backward_compatible(self): self.sample_softmax = -1 @@ -826,7 +831,6 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val) self.init_weights() - self.tie_weights() def tie_weights(self): """ diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py index 03490630ed..3c2d8e199d 100644 --- a/transformers/modeling_utils.py +++ b/transformers/modeling_utils.py @@ -83,6 +83,77 @@ class PreTrainedModel(nn.Module): # Save config in model self.config = config + @property + def base_model(self): + return getattr(self, self.base_model_prefix, self) + + @property + def input_embeddings(self): + base_model = getattr(self, self.base_model_prefix, self) + return base_model.input_embeddings + + @property + def output_embeddings(self): + return None # Overwrite for models with output embeddings + + def tie_weights(self): + """ Make sure we are sharing the input and output embeddings. + Export to TorchScript can't handle parameter sharing so we are cloning them instead. 
+ """ + if self.output_embeddings is not None: + self._tie_or_clone_weights(self.output_embeddings, self.input_embeddings) + + def _tie_or_clone_weights(self, output_embeddings, input_embeddings): + """ Tie or clone module weights depending on whether we are using TorchScript or not + """ + if self.config.torchscript: + output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone()) + else: + output_embeddings.weight = input_embeddings.weight + + if hasattr(output_embeddings, 'bias') and output_embeddings.bias is not None: + output_embeddings.bias.data = torch.nn.functional.pad( + output_embeddings.bias.data, + (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]), + 'constant', + 0 + ) + if hasattr(output_embeddings, 'out_features') and hasattr(input_embeddings, 'num_embeddings'): + output_embeddings.out_features = input_embeddings.num_embeddings + + def resize_token_embeddings(self, new_num_tokens=None): + """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. + Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + + new_num_tokens: (`optional`) int: + New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. + If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model. 
+ + Return: ``torch.nn.Embeddings`` + Pointer to the input tokens Embeddings Module of the model + """ + base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + model_embeds = base_model._resize_token_embeddings(new_num_tokens) + if new_num_tokens is None: + return model_embeds + + # Update base model and current model config + self.config.vocab_size = new_num_tokens + base_model.vocab_size = new_num_tokens + + # Tie weights again if needed + if hasattr(self, 'tie_weights'): + self.tie_weights() + + return model_embeds + + def _resize_token_embeddings(self, new_num_tokens): + old_embeddings = self.input_embeddings + self.input_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + return self.input_embeddings + def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): """ Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly initialized vectors at the end @@ -117,50 +188,6 @@ class PreTrainedModel(nn.Module): return new_embeddings - def _tie_or_clone_weights(self, first_module, second_module): - """ Tie or clone module weights depending of weither we are using TorchScript or not - """ - if self.config.torchscript: - first_module.weight = nn.Parameter(second_module.weight.clone()) - else: - first_module.weight = second_module.weight - - if hasattr(first_module, 'bias') and first_module.bias is not None: - first_module.bias.data = torch.nn.functional.pad( - first_module.bias.data, - (0, first_module.weight.shape[0] - first_module.bias.shape[0]), - 'constant', - 0 - ) - - def resize_token_embeddings(self, new_num_tokens=None): - """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. - Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. - - Arguments: - - new_num_tokens: (`optional`) int: - New number of tokens in the embedding matrix. 
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end. - If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model. - - Return: ``torch.nn.Embeddings`` - Pointer to the input tokens Embeddings Module of the model - """ - base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed - model_embeds = base_model._resize_token_embeddings(new_num_tokens) - if new_num_tokens is None: - return model_embeds - - # Update base model and current model config - self.config.vocab_size = new_num_tokens - base_model.vocab_size = new_num_tokens - - # Tie weights again if needed - if hasattr(self, 'tie_weights'): - self.tie_weights() - - return model_embeds - def init_weights(self): """ Initialize and prunes weights if needed. """ # Initialize weights @@ -170,6 +197,9 @@ class PreTrainedModel(nn.Module): if self.config.pruned_heads: self.prune_heads(self.config.pruned_heads) + # Tie weights if needed + self.tie_weights() + def prune_heads(self, heads_to_prune): """ Prunes heads of the base model. @@ -178,14 +208,12 @@ class PreTrainedModel(nn.Module): heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`). E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2. 
""" - base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed - # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads for layer, heads in heads_to_prune.items(): union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON - base_model._prune_heads(heads_to_prune) + self.base_model._prune_heads(heads_to_prune) def save_pretrained(self, save_directory): """ Save a model and its configuration file to a directory, so that it diff --git a/transformers/modeling_xlm.py b/transformers/modeling_xlm.py index 166b98de63..396632d55e 100644 --- a/transformers/modeling_xlm.py +++ b/transformers/modeling_xlm.py @@ -407,10 +407,14 @@ class XLMModel(XLMPreTrainedModel): self.init_weights() - def _resize_token_embeddings(self, new_num_tokens): - self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens) + @property + def input_embeddings(self): return self.embeddings + @input_embeddings.setter + def input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} @@ -618,12 +622,10 @@ class XLMWithLMHeadModel(XLMPreTrainedModel): self.pred_layer = XLMPredLayer(config) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the embeddings - """ - self._tie_or_clone_weights(self.pred_layer.proj, self.transformer.embeddings) + @property + def output_embeddings(self): + return self.pred_layer.proj def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None, lengths=None, cache=None, head_mask=None, labels=None): diff --git a/transformers/modeling_xlnet.py b/transformers/modeling_xlnet.py index e191ebadd0..3173616eb8 100644 --- a/transformers/modeling_xlnet.py +++ b/transformers/modeling_xlnet.py @@ -611,10 +611,14 @@ class XLNetModel(XLNetPreTrainedModel): self.init_weights() - def _resize_token_embeddings(self, new_num_tokens): - self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens) + @property + def input_embeddings(self): return self.word_embedding + @input_embeddings.setter + def input_embeddings(self, new_embeddings): + self.word_embedding = new_embeddings + def _prune_heads(self, heads_to_prune): raise NotImplementedError @@ -918,12 +922,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True) self.init_weights() - self.tie_weights() - def tie_weights(self): - """ Make sure we are sharing the embeddings - """ - self._tie_or_clone_weights(self.lm_loss, self.transformer.word_embedding) + @property + def output_embeddings(self): + return self.lm_loss def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None, token_type_ids=None, input_mask=None, head_mask=None, labels=None): diff --git a/transformers/tests/modeling_common_test.py b/transformers/tests/modeling_common_test.py index 008d7c0d51..300b019dfb 100644 --- 
a/transformers/tests/modeling_common_test.py +++ b/transformers/tests/modeling_common_test.py @@ -463,6 +463,15 @@ class CommonTestCases: self.assertTrue(models_equal) + def test_model_common_attributes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertTrue(hasattr(model, 'input_embeddings')) + setattr(model, 'input_embeddings', torch.nn.Embedding(10, 10)) + self.assertTrue(hasattr(model, 'output_embeddings')) + def test_tie_model_weights(self): if not self.test_torchscript: return