From 77ed9fa1a9cfcf3e956d67e598b76991e17a45f0 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 18 Sep 2023 16:58:38 +0200 Subject: [PATCH] [FSMT] Fix non-shared weights (#26187) * Fix non-shared weights * Add tests * Edit tied weights keys --- src/transformers/models/fsmt/modeling_fsmt.py | 8 ++++++-- tests/models/fsmt/test_modeling_fsmt.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 29d78c1269..eb324b9008 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -1034,7 +1034,7 @@ def _get_shape(t): FSMT_START_DOCSTRING, ) class FSMTModel(PretrainedFSMTModel): - _tied_weights_keys = ["decoder.embed_tokens.weight"] + _tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"] def __init__(self, config: FSMTConfig): super().__init__(config) @@ -1055,6 +1055,10 @@ class FSMTModel(PretrainedFSMTModel): def get_decoder(self): return self.decoder + def _tie_weights(self): + self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings()) + self._tie_or_clone_weights(self.decoder.output_projection, self.get_input_embeddings()) + @add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1171,7 +1175,7 @@ class FSMTModel(PretrainedFSMTModel): ) class FSMTForConditionalGeneration(PretrainedFSMTModel): base_model_prefix = "model" - _tied_weights_keys = ["model.decoder.embed_tokens.weight"] + _tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"] def __init__(self, config: FSMTConfig): super().__init__(config) diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py index 7a39e82aea..2368e23021 100644 --- a/tests/models/fsmt/test_modeling_fsmt.py +++ b/tests/models/fsmt/test_modeling_fsmt.py @@ -271,6 +271,23 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin input_names=["input_ids", "attention_mask"], ) + def test_ensure_weights_are_shared(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + model = FSMTForConditionalGeneration(config) + + # FSMT shares three weights. + # Not an issue to not have these correctly tied for torch.load, but it is an issue for safetensors. + self.assertEqual( + len( + { + model.get_output_embeddings().weight.data_ptr(), + model.get_input_embeddings().weight.data_ptr(), + model.base_model.decoder.output_projection.weight.data_ptr(), + } + ), + 1, + ) + @unittest.skip("can't be implemented for FSMT due to dual vocab.") def test_resize_tokens_embeddings(self): pass