From 6dc0a849b7f0d708fd0c8e3cc3b3c407d70335a1 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 20 Apr 2023 15:50:31 +0100 Subject: [PATCH] Fix weight tying in TF-ESM (#22839) Fix weight tying in ESM --- .../models/esm/modeling_tf_esm.py | 25 +++++++++++++------ tests/models/esm/test_modeling_tf_esm.py | 18 +++++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/esm/modeling_tf_esm.py b/src/transformers/models/esm/modeling_tf_esm.py index 2bb25ba94d..0ed988e4c8 100644 --- a/src/transformers/models/esm/modeling_tf_esm.py +++ b/src/transformers/models/esm/modeling_tf_esm.py @@ -14,6 +14,7 @@ # limitations under the License. """ PyTorch ESM model.""" +import os from typing import Optional, Tuple, Union import numpy as np @@ -1102,6 +1103,11 @@ class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss): self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") self.lm_head = TFEsmLMHead(config, name="lm_head") + if config.tie_word_embeddings: + # Ensure word embeddings are built so that we actually have something to tie + with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")): + self.esm.embeddings.word_embeddings.build((None, None)) + self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0] def get_output_embeddings(self): return self.lm_head.decoder @@ -1211,18 +1217,22 @@ class TFEsmLMHead(Layer): self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") - self.decoder = Dense( - config.vocab_size, - use_bias=False, - kernel_initializer=get_initializer(config.initializer_range), - name="decoder", - ) + self.decoder = None self.config = config def build(self, input_shape): super().build(input_shape) # Separate bias to match the PT model and allow weight cross-loading to work # Put it in the build so it gets the right name when adding it as a weight + if not self.config.tie_word_embeddings: + if self.decoder is not None: + raise ValueError("Expected decoder not to be initialized before build when not tying weights!") + self.decoder = self.add_weight( + "decoder.weight", + shape=(self.config.hidden_size, self.config.vocab_size), + initializer=get_initializer(self.config.initializer_range), + trainable=True, + ) self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True) def get_bias(self): @@ -1234,8 +1244,7 @@ class TFEsmLMHead(Layer): x = self.layer_norm(x) # project back to size of vocabulary with bias - x = self.decoder(x) - x = x + self.bias + x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias return x diff --git a/tests/models/esm/test_modeling_tf_esm.py b/tests/models/esm/test_modeling_tf_esm.py index 663642bde2..dc9d430d07 100644 --- a/tests/models/esm/test_modeling_tf_esm.py +++ b/tests/models/esm/test_modeling_tf_esm.py @@ -262,6 +262,24 @@ class TFEsmModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase) def test_save_load_after_resize_token_embeddings(self): pass + def test_model_common_attributes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) + if model_class is TFEsmForMaskedLM: + # Output embedding test differs from the main test because they're a matrix, not a layer + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None + @require_tf class TFEsmModelIntegrationTest(unittest.TestCase):