From 3f290e6c8403c6a2cf80dce068869793bde49540 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Thu, 21 Jan 2021 13:00:11 +0100 Subject: [PATCH] Fix mixed precision in TF models (#9163) * Fix Gelu precision * Fix gelu_fast * Naming * Fix usage and apply style * add TF gelu approximate version * add TF gelu approximate version * add TF gelu approximate version * Apply style * Fix albert * Remove the usage of the Activation layer --- src/transformers/activations_tf.py | 29 ++++++++++++++----- .../models/albert/modeling_tf_albert.py | 2 +- .../models/bert/modeling_tf_bert.py | 2 +- .../models/electra/modeling_tf_electra.py | 2 +- .../longformer/modeling_tf_longformer.py | 2 +- .../models/mpnet/modeling_tf_mpnet.py | 2 +- .../models/roberta/modeling_tf_roberta.py | 2 +- ...tf_{{cookiecutter.lowercase_modelname}}.py | 2 +- 8 files changed, 28 insertions(+), 15 deletions(-) diff --git a/src/transformers/activations_tf.py b/src/transformers/activations_tf.py index 3326a6fc61..929dbb310a 100644 --- a/src/transformers/activations_tf.py +++ b/src/transformers/activations_tf.py @@ -15,9 +15,10 @@ import math import tensorflow as tf +from packaging import version -def gelu(x): +def _gelu(x): """ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): @@ -25,12 +26,12 @@ def gelu(x): https://arxiv.org/abs/1606.08415 """ x = tf.convert_to_tensor(x) - cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0))) + cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype))) return x * cdf -def gelu_new(x): +def _gelu_new(x): """ Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.0841 @@ -56,21 +57,33 @@ def mish(x): def gelu_fast(x): x = tf.convert_to_tensor(x) - coeff1 = tf.cast(7978845608, x.dtype) + coeff1 = tf.cast(0.7978845608, x.dtype) coeff2 = tf.cast(0.044715, x.dtype) return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x))) +if version.parse(tf.version.VERSION) >= version.parse("2.4"): + + def approximate_gelu_wrap(x): + return tf.keras.activations.gelu(x, approximate=True) + + gelu = tf.keras.activations.gelu + gelu_new = approximate_gelu_wrap +else: + gelu = _gelu + gelu_new = _gelu_new + + ACT2FN = { - "gelu": tf.keras.layers.Activation(gelu), + "gelu": gelu, "relu": tf.keras.activations.relu, "swish": tf.keras.activations.swish, "silu": tf.keras.activations.swish, - "gelu_new": tf.keras.layers.Activation(gelu_new), - "mish": tf.keras.layers.Activation(mish), + "gelu_new": gelu_new, + "mish": mish, "tanh": tf.keras.activations.tanh, - "gelu_fast": tf.keras.layers.Activation(gelu_fast), + "gelu_fast": gelu_fast, } diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py index 2c96a3f597..108f55dcf6 100644 --- a/src/transformers/models/albert/modeling_tf_albert.py +++ b/src/transformers/models/albert/modeling_tf_albert.py @@ -542,7 +542,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): def call(self, hidden_states): hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.activation(inputs=hidden_states) + hidden_states = self.activation(hidden_states) hidden_states = self.LayerNorm(inputs=hidden_states) seq_length = shape_list(tensor=hidden_states)[1] hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py index d1314b17b4..01d51ddaa5 100644 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ b/src/transformers/models/bert/modeling_tf_bert.py @@ -428,7 +428,7 @@ class TFBertIntermediate(tf.keras.layers.Layer): def call(self, hidden_states): hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states diff --git a/src/transformers/models/electra/modeling_tf_electra.py b/src/transformers/models/electra/modeling_tf_electra.py index 1fc725162c..a75c406170 100644 --- a/src/transformers/models/electra/modeling_tf_electra.py +++ b/src/transformers/models/electra/modeling_tf_electra.py @@ -327,7 +327,7 @@ class TFElectraIntermediate(tf.keras.layers.Layer): def call(self, hidden_states): hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 71fdfc150d..478125bfbf 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -709,7 +709,7 @@ class TFLongformerIntermediate(tf.keras.layers.Layer): def call(self, hidden_states): hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states diff --git a/src/transformers/models/mpnet/modeling_tf_mpnet.py b/src/transformers/models/mpnet/modeling_tf_mpnet.py index e029acd2db..e1ff0ba701 100644 --- a/src/transformers/models/mpnet/modeling_tf_mpnet.py +++ b/src/transformers/models/mpnet/modeling_tf_mpnet.py @@ -388,7 +388,7 @@ class TFMPNetIntermediate(tf.keras.layers.Layer): def call(self, hidden_states): hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index 162aa2a197..e8f3e18880 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -448,7 +448,7 @@ class TFRobertaIntermediate(tf.keras.layers.Layer): def call(self, hidden_states): hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index ce0cc3a63f..5131c833a8 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -382,7 +382,7 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer): def call(self, hidden_states): hidden_states = self.dense(inputs=hidden_states) - hidden_states = self.intermediate_act_fn(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states