add TF glu activation function (#15146)
This commit is contained in:
@@ -69,6 +69,22 @@ def quick_gelu(x):
|
||||
return x * tf.math.sigmoid(coeff * x)
|
||||
|
||||
|
||||
def glu(x, axis=-1):
|
||||
"""
|
||||
Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
|
||||
the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).
|
||||
|
||||
Args:
|
||||
`x`: float Tensor to perform activation
|
||||
`axis`: dimension across which `x` be split in half
|
||||
|
||||
Returns:
|
||||
`x` with the GLU activation applied (with its size halved across the dimension `axis`).
|
||||
"""
|
||||
a, b = tf.split(x, 2, axis=axis)
|
||||
return a * tf.math.sigmoid(b)
|
||||
|
||||
|
||||
if version.parse(tf.version.VERSION) >= version.parse("2.4"):
|
||||
|
||||
def approximate_gelu_wrap(x):
|
||||
@@ -91,6 +107,7 @@ ACT2FN = {
|
||||
"tanh": tf.keras.activations.tanh,
|
||||
"gelu_fast": gelu_fast,
|
||||
"quick_gelu": quick_gelu,
|
||||
"glu": glu,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -33,6 +33,8 @@ class TestTFActivations(unittest.TestCase):
|
||||
get_tf_activation("gelu_new")
|
||||
get_tf_activation("gelu_fast")
|
||||
get_tf_activation("mish")
|
||||
get_tf_activation("quick_gelu")
|
||||
get_tf_activation("glu")
|
||||
with self.assertRaises(KeyError):
|
||||
get_tf_activation("bogus")
|
||||
with self.assertRaises(KeyError):
|
||||
|
||||
Reference in New Issue
Block a user