add TF glu activation function (#15146)
This commit is contained in:
@@ -69,6 +69,22 @@ def quick_gelu(x):
|
|||||||
return x * tf.math.sigmoid(coeff * x)
|
return x * tf.math.sigmoid(coeff * x)
|
||||||
|
|
||||||
|
|
||||||
|
def glu(x, axis=-1):
|
||||||
|
"""
|
||||||
|
Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
|
||||||
|
the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
`x`: float Tensor to perform activation
|
||||||
|
`axis`: dimension across which `x` be split in half
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`x` with the GLU activation applied (with its size halved across the dimension `axis`).
|
||||||
|
"""
|
||||||
|
a, b = tf.split(x, 2, axis=axis)
|
||||||
|
return a * tf.math.sigmoid(b)
|
||||||
|
|
||||||
|
|
||||||
if version.parse(tf.version.VERSION) >= version.parse("2.4"):
|
if version.parse(tf.version.VERSION) >= version.parse("2.4"):
|
||||||
|
|
||||||
def approximate_gelu_wrap(x):
|
def approximate_gelu_wrap(x):
|
||||||
@@ -91,6 +107,7 @@ ACT2FN = {
|
|||||||
"tanh": tf.keras.activations.tanh,
|
"tanh": tf.keras.activations.tanh,
|
||||||
"gelu_fast": gelu_fast,
|
"gelu_fast": gelu_fast,
|
||||||
"quick_gelu": quick_gelu,
|
"quick_gelu": quick_gelu,
|
||||||
|
"glu": glu,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ class TestTFActivations(unittest.TestCase):
|
|||||||
get_tf_activation("gelu_new")
|
get_tf_activation("gelu_new")
|
||||||
get_tf_activation("gelu_fast")
|
get_tf_activation("gelu_fast")
|
||||||
get_tf_activation("mish")
|
get_tf_activation("mish")
|
||||||
|
get_tf_activation("quick_gelu")
|
||||||
|
get_tf_activation("glu")
|
||||||
with self.assertRaises(KeyError):
|
with self.assertRaises(KeyError):
|
||||||
get_tf_activation("bogus")
|
get_tf_activation("bogus")
|
||||||
with self.assertRaises(KeyError):
|
with self.assertRaises(KeyError):
|
||||||
|
|||||||
Reference in New Issue
Block a user