@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch ESM model."""
|
""" PyTorch ESM model."""
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -1102,6 +1103,11 @@ class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
|
|||||||
|
|
||||||
self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
|
self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
|
||||||
self.lm_head = TFEsmLMHead(config, name="lm_head")
|
self.lm_head = TFEsmLMHead(config, name="lm_head")
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
# Ensure word embeddings are built so that we actually have something to tie
|
||||||
|
with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")):
|
||||||
|
self.esm.embeddings.word_embeddings.build((None, None))
|
||||||
|
self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]
|
||||||
|
|
||||||
def get_output_embeddings(self):
    """Return whatever is currently stored as the LM head's ``decoder`` attribute."""
    decoder = self.lm_head.decoder
    return decoder
|
||||||
@@ -1211,18 +1217,22 @@ class TFEsmLMHead(Layer):
|
|||||||
|
|
||||||
self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
|
self.layer_norm = LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
|
||||||
|
|
||||||
self.decoder = Dense(
|
self.decoder = None
|
||||||
config.vocab_size,
|
|
||||||
use_bias=False,
|
|
||||||
kernel_initializer=get_initializer(config.initializer_range),
|
|
||||||
name="decoder",
|
|
||||||
)
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def build(self, input_shape):
    """Create the weights owned by this head: the output bias and, when untied, the decoder matrix.

    Args:
        input_shape: Shape passed through unchanged to the parent ``build``.

    Raises:
        ValueError: If ``config.tie_word_embeddings`` is false but ``self.decoder`` has already
            been assigned before this method runs.
    """
    super().build(input_shape)
    # Separate bias to match the PT model and allow weight cross-loading to work
    # Put it in the build so it gets the right name when adding it as a weight
    if not self.config.tie_word_embeddings:
        if self.decoder is not None:
            raise ValueError("Expected decoder not to be initialized before build when not tying weights!")
        # The projection is computed as `tf.matmul(x, self.decoder, transpose_b=True)`, so the
        # matrix has to be laid out as (vocab_size, hidden_size); a (hidden_size, vocab_size)
        # weight would make that matmul fail whenever the two sizes differ.
        self.decoder = self.add_weight(
            "decoder.weight",
            shape=(self.config.vocab_size, self.config.hidden_size),
            initializer=get_initializer(self.config.initializer_range),
            trainable=True,
        )
    self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
|
||||||
|
|
||||||
def get_bias(self):
|
def get_bias(self):
|
||||||
@@ -1234,8 +1244,7 @@ class TFEsmLMHead(Layer):
|
|||||||
x = self.layer_norm(x)
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
# project back to size of vocabulary with bias
|
# project back to size of vocabulary with bias
|
||||||
x = self.decoder(x)
|
x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
|
||||||
x = x + self.bias
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -262,6 +262,24 @@ class TFEsmModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
|||||||
def test_save_load_after_resize_token_embeddings(self):
|
def test_save_load_after_resize_token_embeddings(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_model_common_attributes(self):
    """Check the embedding and bias accessors of every model class under test.

    Every model must expose its input embeddings as a Keras layer. For
    ``TFEsmForMaskedLM`` the bias accessor must return a dict whose values are
    ``tf.Variable`` objects; for all other classes both the output-embedding
    and bias accessors must return ``None``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)

        if model_class is not TFEsmForMaskedLM:
            # Models without an LM head expose neither output embeddings nor a bias.
            output_embeddings = model.get_output_embeddings()
            assert output_embeddings is None
            bias = model.get_bias()
            assert bias is None
            continue

        # Output embedding test differs from the main test because they're a matrix, not a layer
        bias_by_name = model.get_bias()
        assert isinstance(bias_by_name, dict)
        for bias_variable in bias_by_name.values():
            assert isinstance(bias_variable, tf.Variable)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFEsmModelIntegrationTest(unittest.TestCase):
|
class TFEsmModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user