Fix TF CTRL model naming (#6134)
This commit is contained in:
@@ -141,11 +141,18 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
return outputs
|
||||
|
||||
|
||||
def point_wise_feed_forward_network(d_model_size, dff, name=""):
    """Build the two-layer point-wise feed-forward block as a Keras Sequential.

    The block is ``Dense(dff, activation="relu")`` followed by
    ``Dense(d_model_size)``. The two sub-layers are named ``"0"`` and ``"2"``.

    Args:
        d_model_size: Number of units of the second (output) Dense layer.
        dff: Number of units of the first (hidden, ReLU) Dense layer.
        name: Name given to the returned ``tf.keras.Sequential``. When empty
            (the default), ``"ffn"`` is used, which is the value that used to
            be hard-coded regardless of this argument.

    Returns:
        A ``tf.keras.Sequential`` holding the two Dense layers.
    """
    # Previously `name` was accepted but silently ignored; honor it while
    # keeping "ffn" as the fallback so default callers see no change.
    return tf.keras.Sequential(
        [
            tf.keras.layers.Dense(dff, activation="relu", name="0"),
            tf.keras.layers.Dense(d_model_size, name="2"),
        ],
        name=name or "ffn",
    )
|
||||
class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
    """Keras layer that applies two Dense sub-layers one after the other.

    The first sub-layer has ``dff`` units with a ReLU activation and is named
    ``"0"``; the second has ``d_model_size`` units, no activation argument,
    and is named ``"2"``.
    """

    def __init__(self, d_model_size, dff, **kwargs):
        """Create the two Dense sub-layers.

        Args:
            d_model_size: Number of units of the second Dense layer.
            dff: Number of units of the first (ReLU) Dense layer.
            **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)

        # The sub-layer names are fixed strings; keep them exactly as-is.
        self.dense_0 = tf.keras.layers.Dense(units=dff, activation="relu", name="0")
        self.dense_2 = tf.keras.layers.Dense(units=d_model_size, name="2")

    def call(self, inputs, trainable=False):
        """Run ``inputs`` through ``dense_0`` and then ``dense_2``.

        Args:
            inputs: Value passed to the first Dense sub-layer.
            trainable: Accepted for interface compatibility; not read here.

        Returns:
            The output of the second Dense sub-layer.
        """
        hidden = self.dense_0(inputs)
        return self.dense_2(hidden)
|
||||
|
||||
|
||||
class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
@@ -153,7 +160,7 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
|
||||
self.ffn = point_wise_feed_forward_network(d_model_size, dff, name="ffn")
|
||||
self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")
|
||||
|
||||
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
|
||||
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")
|
||||
|
||||
Reference in New Issue
Block a user