From 725f4ad1ccad4e1aeb309688706b56713070334b Mon Sep 17 00:00:00 2001 From: "JB (Don)" <1557853+hackyon@users.noreply.github.com> Date: Thu, 15 Feb 2024 04:39:01 +0800 Subject: [PATCH] Add tie_weights() to LM heads and set bias in set_output_embeddings() (#28948) * Add tie_weights() to LM heads and set bias in set_output_embeddings() The bias were not tied correctly in some LM heads, and this change should fix that. * Moving test_save_and_load_low_cpu_mem_usage to ModelTesterMixin * Adding _tie_weights() to MPNet and Vilt * Skip test for low cpu mem usage for Deta/DeformableDetr since they cannot init on meta device * Rename to test name to save_load to match the convention --- src/transformers/models/bert/modeling_bert.py | 6 ++++++ .../models/big_bird/modeling_big_bird.py | 6 ++++++ .../models/blip/modeling_blip_text.py | 4 ++++ src/transformers/models/ernie/modeling_ernie.py | 6 ++++++ .../models/layoutlm/modeling_layoutlm.py | 4 ++++ .../models/markuplm/modeling_markuplm.py | 3 +++ .../megatron_bert/modeling_megatron_bert.py | 6 ++++++ src/transformers/models/mpnet/modeling_mpnet.py | 4 ++++ src/transformers/models/mra/modeling_mra.py | 4 ++++ src/transformers/models/nezha/modeling_nezha.py | 5 +++++ .../nystromformer/modeling_nystromformer.py | 4 ++++ .../models/qdqbert/modeling_qdqbert.py | 5 +++++ .../models/roc_bert/modeling_roc_bert.py | 6 ++++++ src/transformers/models/tapas/modeling_tapas.py | 4 ++++ src/transformers/models/vilt/modeling_vilt.py | 4 ++++ .../models/visual_bert/modeling_visual_bert.py | 4 ++++ src/transformers/models/yoso/modeling_yoso.py | 4 ++++ .../test_modeling_deformable_detr.py | 4 ++++ tests/models/deta/test_modeling_deta.py | 4 ++++ tests/test_modeling_common.py | 17 +++++++++++++++++ 20 files changed, 104 insertions(+) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index c6764c771e..3eff144700 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -692,6 +692,9 @@ class BertLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1062,6 +1065,7 @@ class BertForPreTraining(BertPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1171,6 +1175,7 @@ class BertLMHeadModel(BertPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1324,6 +1329,7 @@ class BertForMaskedLM(BertPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 008985f760..6e3af915cf 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1707,6 +1707,9 @@ class BigBirdLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -2266,6 +2269,7 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -2378,6 +2382,7 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @@ -2519,6 +2524,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 353c0f486a..f9ae08b667 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -523,6 +523,9 @@ class BlipTextLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -816,6 +819,7 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias def forward( self, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 291ab6c54d..1a1e49dcbf 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -608,6 +608,9 @@ class ErnieLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -995,6 +998,7 @@ class ErnieForPreTraining(ErniePreTrainedModel): # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1109,6 +1113,7 @@ class ErnieForCausalLM(ErniePreTrainedModel): # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1269,6 +1274,7 @@ class ErnieForMaskedLM(ErniePreTrainedModel): # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index c2ecede73d..70d11573d9 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -589,6 +589,9 @@ class LayoutLMLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -869,6 +872,7 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 24ca0c4972..8d95bcc0c1 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -318,6 +318,9 @@ class MarkupLMLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 9111f937bc..0fd9127bab 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -659,6 +659,9 @@ class MegatronBertLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1023,6 +1026,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1132,6 +1136,7 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @@ -1290,6 +1295,7 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py index 86194607e2..43cfaa5e69 100644 --- a/src/transformers/models/mpnet/modeling_mpnet.py +++ b/src/transformers/models/mpnet/modeling_mpnet.py @@ -587,6 +587,7 @@ class MPNetForMaskedLM(MPNetPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -659,6 +660,9 @@ class MPNetLMHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 7e81f2a46c..d11c255771 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -820,6 +820,9 @@ class MraLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1053,6 +1056,7 @@ class MraForMaskedLM(MraPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index 918a10b275..8fc2041e93 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -679,6 +679,9 @@ class NezhaLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1044,6 +1047,7 @@ class NezhaForPreTraining(NezhaPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1152,6 +1156,7 @@ class NezhaForMaskedLM(NezhaPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 950f8d27fa..1bba9fb1f8 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -428,6 +428,9 @@ class NystromformerLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -666,6 +669,7 @@ class NystromformerForMaskedLM(NystromformerPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index 33d6d6b208..5e7704c77c 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -683,6 +683,9 @@ class QDQBertLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1024,6 +1027,7 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @@ -1190,6 +1194,7 @@ class QDQBertForMaskedLM(QDQBertPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index f3de92fed3..ded234b71c 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -744,6 +744,9 @@ class RoCBertLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1090,6 +1093,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel): # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @@ -1282,6 +1286,7 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel): # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def forward( @@ -1419,6 +1424,7 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel): # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 1e7a4372bb..1ee233ea9d 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -729,6 +729,9 @@ class TapasLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1008,6 +1011,7 @@ class TapasForMaskedLM(TapasPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 9ffa9fff01..5e53d4332b 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -896,6 +896,7 @@ class ViltForMaskedLM(ViltPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.mlm_score.decoder = new_embeddings + self.mlm_score.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @@ -1042,6 +1043,9 @@ class ViltMLMHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, x): x = self.transform(x) x = self.decoder(x) diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index f8a146ed2c..f81f7b04c8 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -499,6 +499,9 @@ class VisualBertLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -879,6 +882,7 @@ class VisualBertForPreTraining(VisualBertPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 4e08b999ad..9c0636340d 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -619,6 +619,9 @@ class YosoLMPredictionHead(nn.Module): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -857,6 +860,7 @@ class YosoForMaskedLM(YosoPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py index 336f2437c4..2d5a0deec3 100644 --- a/tests/models/deformable_detr/test_modeling_deformable_detr.py +++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py @@ -564,6 +564,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + @unittest.skip("Cannot be initialized on meta device as some weights are modified during the initialization") + def test_save_load_low_cpu_mem_usage(self): + pass + def test_two_stage_training(self): model_class = DeformableDetrForObjectDetection config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py index 3a3a957dd0..ffebfd38d0 100644 --- a/tests/models/deta/test_modeling_deta.py +++ b/tests/models/deta/test_modeling_deta.py @@ -520,6 +520,10 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + @unittest.skip("Cannot be initialized on meta device as some weights are modified during the initialization") + def test_save_load_low_cpu_mem_usage(self): + pass + TOLERANCE = 1e-4 diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 32f6abcbe3..dfe613fa1f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -435,6 +435,23 @@ class ModelTesterMixin: max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + def test_save_load_low_cpu_mem_usage(self): + with tempfile.TemporaryDirectory() as tmpdirname: + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model_to_save = model_class(config) + + model_to_save.save_pretrained(tmpdirname) + + model = model_class.from_pretrained( + tmpdirname, + low_cpu_mem_usage=True, + ) + + # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta. If there are + # any unloaded or untied parameters, then trying to move it to device=torch_device will throw an error. + model.to(torch_device) + def test_fast_init_context_manager(self): # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__ class MyClass(PreTrainedModel):