From f09c45e06753c439d56843d2bd00e7f8c91fc7ee Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 19 Apr 2022 16:13:08 +0100 Subject: [PATCH] TF: Add sigmoid activation function (#16819) --- src/transformers/activations.py | 18 +++++++++--------- src/transformers/activations_tf.py | 17 +++++++++-------- tests/utils/test_activations.py | 19 ++++++++++--------- tests/utils/test_activations_tf.py | 15 ++++++++------- 4 files changed, 36 insertions(+), 33 deletions(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 07421bfa55..fad8d10613 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -152,19 +152,19 @@ class LinearActivation(nn.Module): ACT2FN = { + "gelu": GELUActivation(), + "gelu_10": ClippedGELUActivation(-10, 10), + "gelu_fast": FastGELUActivation(), + "gelu_new": NewGELUActivation(), + "gelu_python": GELUActivation(use_gelu_python=True), + "linear": LinearActivation(), + "mish": MishActivation(), + "quick_gelu": QuickGELUActivation(), "relu": nn.ReLU(), + "sigmoid": nn.Sigmoid(), "silu": SiLUActivation(), "swish": SiLUActivation(), - "gelu": GELUActivation(), "tanh": nn.Tanh(), - "gelu_python": GELUActivation(use_gelu_python=True), - "gelu_new": NewGELUActivation(), - "gelu_fast": FastGELUActivation(), - "quick_gelu": QuickGELUActivation(), - "gelu_10": ClippedGELUActivation(-10, 10), - "mish": MishActivation(), - "linear": LinearActivation(), - "sigmoid": nn.Sigmoid(), } diff --git a/src/transformers/activations_tf.py b/src/transformers/activations_tf.py index ba74e9850e..4fcb1493e4 100644 --- a/src/transformers/activations_tf.py +++ b/src/transformers/activations_tf.py @@ -113,16 +113,17 @@ else: ACT2FN = { "gelu": gelu, - "relu": tf.keras.activations.relu, - "swish": tf.keras.activations.swish, - "silu": tf.keras.activations.swish, - "gelu_new": gelu_new, - "mish": mish, - "tanh": tf.keras.activations.tanh, - "gelu_fast": gelu_fast, - "quick_gelu": quick_gelu, "gelu_10": gelu_10, + "gelu_fast": gelu_fast, + "gelu_new": gelu_new, "glu": glu, + "mish": mish, + "quick_gelu": quick_gelu, + "relu": tf.keras.activations.relu, + "sigmoid": tf.keras.activations.sigmoid, + "silu": tf.keras.activations.swish, + "swish": tf.keras.activations.swish, + "tanh": tf.keras.activations.tanh, } diff --git a/tests/utils/test_activations.py b/tests/utils/test_activations.py index 339d2eda16..29e487ee97 100644 --- a/tests/utils/test_activations.py +++ b/tests/utils/test_activations.py @@ -46,18 +46,19 @@ class TestActivations(unittest.TestCase): self.assertTrue(torch.allclose(y_gelu * clipped_mask, y_gelu_10 * clipped_mask)) def test_get_activation(self): - get_activation("swish") - get_activation("silu") - get_activation("relu") - get_activation("tanh") - get_activation("gelu_new") - get_activation("gelu_fast") - get_activation("gelu_python") + get_activation("gelu") get_activation("gelu_10") - get_activation("quick_gelu") - get_activation("mish") + get_activation("gelu_fast") + get_activation("gelu_new") + get_activation("gelu_python") get_activation("linear") + get_activation("mish") + get_activation("quick_gelu") + get_activation("relu") get_activation("sigmoid") + get_activation("silu") + get_activation("swish") + get_activation("tanh") with self.assertRaises(KeyError): get_activation("bogus") with self.assertRaises(KeyError): diff --git a/tests/utils/test_activations_tf.py b/tests/utils/test_activations_tf.py index 9c02639835..8d418d7fe3 100644 --- a/tests/utils/test_activations_tf.py +++ b/tests/utils/test_activations_tf.py @@ -42,17 +42,18 @@ class TestTFActivations(unittest.TestCase): self.assertTrue(np.allclose(y_gelu * clipped_mask, y_gelu_10 * clipped_mask)) def test_get_activation(self): - get_tf_activation("swish") - get_tf_activation("silu") get_tf_activation("gelu") - get_tf_activation("relu") - get_tf_activation("tanh") - get_tf_activation("gelu_new") - get_tf_activation("gelu_fast") get_tf_activation("gelu_10") + get_tf_activation("gelu_fast") + get_tf_activation("gelu_new") + get_tf_activation("glu") get_tf_activation("mish") get_tf_activation("quick_gelu") - get_tf_activation("glu") + get_tf_activation("relu") + get_tf_activation("sigmoid") + get_tf_activation("silu") + get_tf_activation("swish") + get_tf_activation("tanh") with self.assertRaises(KeyError): get_tf_activation("bogus") with self.assertRaises(KeyError):