Fix TF template (#9234)
This commit is contained in:
@@ -310,18 +310,22 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer):
|
|||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.dense = tf.keras.layers.Dense(
|
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
equation="abc,cd->abd",
|
||||||
|
output_shape=(None, config.intermediate_size),
|
||||||
|
bias_axes="d",
|
||||||
|
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||||
|
name="dense",
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(config.hidden_act, str):
|
if isinstance(config.hidden_act, str):
|
||||||
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
self.intermediate_act_fn = get_tf_activation(activation_string=config.hidden_act)
|
||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(inputs=hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(inputs=hidden_states)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -331,16 +335,20 @@ class TF{{cookiecutter.camelcase_modelname}}Output(tf.keras.layers.Layer):
|
|||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.dense = tf.keras.layers.Dense(
|
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
equation="abc,cd->abd",
|
||||||
|
bias_axes="d",
|
||||||
|
output_shape=(None, config.hidden_size),
|
||||||
|
kernel_initializer=get_initializer(config.initializer_range),
|
||||||
|
name="dense",
|
||||||
)
|
)
|
||||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||||
|
|
||||||
def call(self, hidden_states, input_tensor, training=False):
|
def call(self, hidden_states, input_tensor, training=False):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(inputs=hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user