From baf7e5c927744122c89ab1270c6c312541c7eb41 Mon Sep 17 00:00:00 2001 From: Abdi <48970896+AbdiHaryadi@users.noreply.github.com> Date: Mon, 5 Aug 2024 21:15:36 +0800 Subject: [PATCH] Persist embedding type of BART and mBART models after resize (#32242) * fix: persist embedding type of MBartForConditionalGeneration after resize * fix: persist embedding type of BartForConditionalGeneration after resize --- src/transformers/models/bart/modeling_bart.py | 3 ++- src/transformers/models/mbart/modeling_mbart.py | 3 ++- tests/models/bart/test_modeling_bart.py | 12 ++++++++++++ tests/models/mbart/test_modeling_mbart.py | 12 ++++++++++++ 4 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 5727330329..fa928d05ca 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1431,7 +1431,8 @@ class BartModel(BartPreTrainedModel): super().__init__(config) padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) self.encoder = BartEncoder(config, self.shared) self.decoder = BartDecoder(config, self.shared) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 0782c4d122..6cad7b08f9 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1271,7 +1271,8 @@ class MBartModel(MBartPreTrainedModel): super().__init__(config) padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = 
MBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) self.encoder = MBartEncoder(config, self.shared) self.decoder = MBartDecoder(config, self.shared) diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 20d8e3911d..61a0aa9091 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -518,6 +518,18 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin def test_load_save_without_tied_weights(self): pass + def test_resize_embeddings_persists_embeddings_type(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + + config.scale_embedding = True + model = BartForConditionalGeneration(config) + old_type = type(model.model.decoder.embed_tokens) + + model.resize_token_embeddings(new_num_tokens=config.vocab_size) + + new_type = type(model.model.decoder.embed_tokens) + self.assertIs(old_type, new_type) + def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index 4c0bf291c1..5a8263e119 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -375,6 +375,18 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def test_load_save_without_tied_weights(self): pass + def test_resize_embeddings_persists_embeddings_type(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + + config.scale_embedding = True + model = MBartForConditionalGeneration(config) + old_type = type(model.model.decoder.embed_tokens) + + model.resize_token_embeddings(new_num_tokens=config.vocab_size) + + new_type = type(model.model.decoder.embed_tokens) + self.assertIs(old_type, new_type) + def 
assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""