Fix TF CTRL model naming (#6134)

This commit is contained in:
Julien Plu
2020-07-29 18:20:00 +02:00
committed by GitHub
parent 641b873c13
commit fc64559c45
2 changed files with 17 additions and 10 deletions

View File

@@ -141,11 +141,18 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return outputs
def point_wise_feed_forward_network(d_model_size, dff, name=""):
    """Build the two-layer point-wise feed-forward block as a Keras ``Sequential``.

    Args:
        d_model_size: Number of units of the second (output) ``Dense`` layer.
        dff: Number of units of the first ``Dense`` layer, which uses a ReLU
            activation.
        name: Name given to the returned ``Sequential``. An empty (or ``None``)
            value falls back to ``"ffn"``, the name that used to be hard-coded.

    Returns:
        A ``tf.keras.Sequential`` chaining the two ``Dense`` layers.
    """
    return tf.keras.Sequential(
        [
            # Sub-layer names "0" and "2" are kept verbatim; they become part of
            # the variable names (presumably chosen to line up with an external
            # checkpoint layout -- TODO confirm).
            tf.keras.layers.Dense(dff, activation="relu", name="0"),
            tf.keras.layers.Dense(d_model_size, name="2"),
        ],
        # The `name` argument was previously accepted but ignored; honour it and
        # keep "ffn" as the default so existing callers see no change.
        name=name or "ffn",
    )
class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
    """Keras layer applying two ``Dense`` layers back to back.

    The first ``Dense`` has ``dff`` units with a ReLU activation (layer name
    ``"0"``); the second has ``d_model_size`` units and no activation (layer
    name ``"2"``).
    """

    def __init__(self, d_model_size, dff, **kwargs):
        """Create the two sub-layers.

        Args:
            d_model_size: Number of units of the second ``Dense`` layer.
            dff: Number of units of the first ``Dense`` layer.
            **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer.__init__``
                (e.g. ``name``).
        """
        super().__init__(**kwargs)
        # Attribute names, creation order and layer names are kept as-is: they
        # are visible to code outside this class.
        self.dense_0 = tf.keras.layers.Dense(dff, activation="relu", name="0")
        self.dense_2 = tf.keras.layers.Dense(d_model_size, name="2")

    def call(self, inputs, trainable=False):
        """Run ``inputs`` through ``dense_0`` and then ``dense_2``.

        Args:
            inputs: Value passed to the first ``Dense`` layer.
            trainable: Accepted but not used by this method.

        Returns:
            The output of the second ``Dense`` layer.
        """
        return self.dense_2(self.dense_0(inputs))
class TFEncoderLayer(tf.keras.layers.Layer):
@@ -153,7 +160,7 @@ class TFEncoderLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
self.ffn = point_wise_feed_forward_network(d_model_size, dff, name="ffn")
self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")