From c4f7eb124b218741d66dd1d86b5d744024a78f6f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 14 Jan 2022 10:42:08 +0000 Subject: [PATCH] add TF glu activation function (#15146) --- src/transformers/activations_tf.py | 17 +++++++++++++++++ tests/test_activations_tf.py | 2 ++ 2 files changed, 19 insertions(+) diff --git a/src/transformers/activations_tf.py b/src/transformers/activations_tf.py index 5c40803f25..32c333e46a 100644 --- a/src/transformers/activations_tf.py +++ b/src/transformers/activations_tf.py @@ -69,6 +69,22 @@ def quick_gelu(x): return x * tf.math.sigmoid(coeff * x) +def glu(x, axis=-1): + """ + Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where + the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B). + + Args: + `x`: float Tensor to perform activation on + `axis`: dimension across which `x` will be split in half + + Returns: + `x` with the GLU activation applied (with its size halved across the dimension `axis`). + """ + a, b = tf.split(x, 2, axis=axis) + return a * tf.math.sigmoid(b) + + if version.parse(tf.version.VERSION) >= version.parse("2.4"): def approximate_gelu_wrap(x): @@ -91,6 +107,7 @@ ACT2FN = { "tanh": tf.keras.activations.tanh, "gelu_fast": gelu_fast, "quick_gelu": quick_gelu, + "glu": glu, } diff --git a/tests/test_activations_tf.py b/tests/test_activations_tf.py index 6f9ef2e4ce..236c37c333 100644 --- a/tests/test_activations_tf.py +++ b/tests/test_activations_tf.py @@ -33,6 +33,8 @@ class TestTFActivations(unittest.TestCase): get_tf_activation("gelu_new") get_tf_activation("gelu_fast") get_tf_activation("mish") + get_tf_activation("quick_gelu") + get_tf_activation("glu") with self.assertRaises(KeyError): get_tf_activation("bogus") with self.assertRaises(KeyError):