From 161a6461db3c673672edd5b994848b0d50db1b67 Mon Sep 17 00:00:00 2001
From: Julien Plu
Date: Mon, 21 Dec 2020 13:52:16 +0100
Subject: [PATCH] Fix TF template (#9234)

---
Review notes (text in this section is not part of the commit message):

* Both the Intermediate and the Output layer replace `tf.keras.layers.Dense`
  with `tf.keras.layers.experimental.EinsumDense` (equation "abc,cd->abd",
  bias on axis "d") and switch the layer calls to keyword arguments.
* The intermediate activation is deliberately still called positionally.
  `intermediate_act_fn` is either the result of `get_tf_activation(...)` or
  `config.hidden_act` itself, i.e. an arbitrary callable; plain functions
  such as `tf.keras.activations.relu(x, ...)` have no `inputs` parameter,
  so `intermediate_act_fn(inputs=...)` would raise a TypeError for them.
* `get_initializer` is called with `initializer_range=` in both layers.
* The post-image blob id on the `index` line is the one from the originally
  exported commit.

 ...tf_{{cookiecutter.lowercase_modelname}}.py | 28 ++++++++++++-------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
index 5c8ffbfc41..5d505371d0 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
@@ -310,18 +310,22 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            output_shape=(None, config.intermediate_size),
+            bias_axes="d",
+            kernel_initializer=get_initializer(initializer_range=config.initializer_range),
+            name="dense",
         )
 
         if isinstance(config.hidden_act, str):
-            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
+            self.intermediate_act_fn = get_tf_activation(activation_string=config.hidden_act)
         else:
             self.intermediate_act_fn = config.hidden_act
 
     def call(self, hidden_states):
-        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dense(inputs=hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
 
         return hidden_states
 
@@ -331,16 +335,20 @@ class TF{{cookiecutter.camelcase_modelname}}Output(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
 
-        self.dense = tf.keras.layers.Dense(
-            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+        self.dense = tf.keras.layers.experimental.EinsumDense(
+            equation="abc,cd->abd",
+            bias_axes="d",
+            output_shape=(None, config.hidden_size),
+            kernel_initializer=get_initializer(initializer_range=config.initializer_range),
+            name="dense",
         )
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
-        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
+        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
 
     def call(self, hidden_states, input_tensor, training=False):
-        hidden_states = self.dense(hidden_states)
-        hidden_states = self.dropout(hidden_states, training=training)
-        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        hidden_states = self.dense(inputs=hidden_states)
+        hidden_states = self.dropout(inputs=hidden_states, training=training)
+        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
 
         return hidden_states
 