Add AMP for Albert (#10141)
This commit is contained in:
@@ -62,11 +62,11 @@ TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings
|
||||
# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings with Albert->ConvBert
|
||||
class TFConvBertEmbeddings(tf.keras.layers.Layer):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
def __init__(self, config: ConvBertConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
@@ -83,21 +83,21 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
|
||||
self.weight = self.add_weight(
|
||||
name="weight",
|
||||
shape=[self.vocab_size, self.embedding_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
with tf.name_scope("token_type_embeddings"):
|
||||
self.token_type_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.type_vocab_size, self.embedding_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
with tf.name_scope("position_embeddings"):
|
||||
self.position_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.max_position_embeddings, self.embedding_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
Reference in New Issue
Block a user