From 26ba56ccbd8695c2990804c1933d08ad3b9907e0 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 21 Sep 2023 14:46:05 +0200 Subject: [PATCH] Fix FSMT weight sharing (#26292) --- src/transformers/models/fsmt/modeling_fsmt.py | 5 +++-- tests/models/fsmt/test_modeling_fsmt.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index eb324b9008..1e566b150f 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -1056,8 +1056,9 @@ class FSMTModel(PretrainedFSMTModel): return self.decoder def _tie_weights(self): - self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings()) - self._tie_or_clone_weights(self.decoder.output_projection, self.get_input_embeddings()) + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings()) + self._tie_or_clone_weights(self.decoder.output_projection, self.get_input_embeddings()) @add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING) @add_code_sample_docstrings( diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py index 2368e23021..f533da7727 100644 --- a/tests/models/fsmt/test_modeling_fsmt.py +++ b/tests/models/fsmt/test_modeling_fsmt.py @@ -273,6 +273,8 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin def test_ensure_weights_are_shared(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs() + + config.tie_word_embeddings = True model = FSMTForConditionalGeneration(config) # FSMT shares three weights. @@ -288,6 +290,22 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin 1, ) + config.tie_word_embeddings = False + model = FSMTForConditionalGeneration(config) + + # With tie_word_embeddings disabled, the three weights are expected to resolve to two distinct storages. 
+ # Incorrectly tied weights are harmless for torch.load, but they are a problem for safetensors. + self.assertEqual( + len( + { + model.get_output_embeddings().weight.data_ptr(), + model.get_input_embeddings().weight.data_ptr(), + model.base_model.decoder.output_projection.weight.data_ptr(), + } + ), + 2, + ) + @unittest.skip("can't be implemented for FSMT due to dual vocab.") def test_resize_tokens_embeddings(self): pass