Fix mixed precision in TF models (#9163)

* Fix Gelu precision

* Fix gelu_fast

* Naming

* Fix usage and apply style

* add TF gelu approximate version

* add TF gelu approximate version

* add TF gelu approximate version

* Apply style

* Fix albert

* Remove the usage of the Activation layer
This commit is contained in:
Julien Plu
2021-01-21 13:00:11 +01:00
committed by GitHub
parent 248fa1ae72
commit 3f290e6c84
8 changed files with 28 additions and 15 deletions

View File

@@ -15,9 +15,10 @@
import math import math
import tensorflow as tf import tensorflow as tf
from packaging import version
def gelu(x): def _gelu(x):
""" """
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
@@ -25,12 +26,12 @@ def gelu(x):
https://arxiv.org/abs/1606.08415 https://arxiv.org/abs/1606.08415
""" """
x = tf.convert_to_tensor(x) x = tf.convert_to_tensor(x)
cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0))) cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
return x * cdf return x * cdf
def gelu_new(x): def _gelu_new(x):
""" """
Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.08415 Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.08415
@@ -56,21 +57,33 @@ def mish(x):
def gelu_fast(x): def gelu_fast(x):
x = tf.convert_to_tensor(x) x = tf.convert_to_tensor(x)
coeff1 = tf.cast(7978845608, x.dtype) coeff1 = tf.cast(0.7978845608, x.dtype)
coeff2 = tf.cast(0.044715, x.dtype) coeff2 = tf.cast(0.044715, x.dtype)
return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x))) return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
if version.parse(tf.version.VERSION) >= version.parse("2.4"):
def approximate_gelu_wrap(x):
return tf.keras.activations.gelu(x, approximate=True)
gelu = tf.keras.activations.gelu
gelu_new = approximate_gelu_wrap
else:
gelu = _gelu
gelu_new = _gelu_new
ACT2FN = { ACT2FN = {
"gelu": tf.keras.layers.Activation(gelu), "gelu": gelu,
"relu": tf.keras.activations.relu, "relu": tf.keras.activations.relu,
"swish": tf.keras.activations.swish, "swish": tf.keras.activations.swish,
"silu": tf.keras.activations.swish, "silu": tf.keras.activations.swish,
"gelu_new": tf.keras.layers.Activation(gelu_new), "gelu_new": gelu_new,
"mish": tf.keras.layers.Activation(mish), "mish": mish,
"tanh": tf.keras.activations.tanh, "tanh": tf.keras.activations.tanh,
"gelu_fast": tf.keras.layers.Activation(gelu_fast), "gelu_fast": gelu_fast,
} }

View File

@@ -542,7 +542,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(inputs=hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.activation(inputs=hidden_states) hidden_states = self.activation(hidden_states)
hidden_states = self.LayerNorm(inputs=hidden_states) hidden_states = self.LayerNorm(inputs=hidden_states)
seq_length = shape_list(tensor=hidden_states)[1] seq_length = shape_list(tensor=hidden_states)[1]
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])

View File

@@ -428,7 +428,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(inputs=hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(inputs=hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states

View File

@@ -327,7 +327,7 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(inputs=hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(inputs=hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states

View File

@@ -709,7 +709,7 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(inputs=hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(inputs=hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states

View File

@@ -388,7 +388,7 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(inputs=hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(inputs=hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states

View File

@@ -448,7 +448,7 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(inputs=hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(inputs=hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states

View File

@@ -382,7 +382,7 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer):
def call(self, hidden_states): def call(self, hidden_states):
hidden_states = self.dense(inputs=hidden_states) hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(inputs=hidden_states) hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states return hidden_states