TF: Add sigmoid activation function (#16819)
This commit is contained in:
@@ -152,19 +152,19 @@ class LinearActivation(nn.Module):
|
||||
|
||||
|
||||
ACT2FN = {
|
||||
"gelu": GELUActivation(),
|
||||
"gelu_10": ClippedGELUActivation(-10, 10),
|
||||
"gelu_fast": FastGELUActivation(),
|
||||
"gelu_new": NewGELUActivation(),
|
||||
"gelu_python": GELUActivation(use_gelu_python=True),
|
||||
"linear": LinearActivation(),
|
||||
"mish": MishActivation(),
|
||||
"quick_gelu": QuickGELUActivation(),
|
||||
"relu": nn.ReLU(),
|
||||
"sigmoid": nn.Sigmoid(),
|
||||
"silu": SiLUActivation(),
|
||||
"swish": SiLUActivation(),
|
||||
"gelu": GELUActivation(),
|
||||
"tanh": nn.Tanh(),
|
||||
"gelu_python": GELUActivation(use_gelu_python=True),
|
||||
"gelu_new": NewGELUActivation(),
|
||||
"gelu_fast": FastGELUActivation(),
|
||||
"quick_gelu": QuickGELUActivation(),
|
||||
"gelu_10": ClippedGELUActivation(-10, 10),
|
||||
"mish": MishActivation(),
|
||||
"linear": LinearActivation(),
|
||||
"sigmoid": nn.Sigmoid(),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -113,16 +113,17 @@ else:
|
||||
|
||||
ACT2FN = {
|
||||
"gelu": gelu,
|
||||
"relu": tf.keras.activations.relu,
|
||||
"swish": tf.keras.activations.swish,
|
||||
"silu": tf.keras.activations.swish,
|
||||
"gelu_new": gelu_new,
|
||||
"mish": mish,
|
||||
"tanh": tf.keras.activations.tanh,
|
||||
"gelu_fast": gelu_fast,
|
||||
"quick_gelu": quick_gelu,
|
||||
"gelu_10": gelu_10,
|
||||
"gelu_fast": gelu_fast,
|
||||
"gelu_new": gelu_new,
|
||||
"glu": glu,
|
||||
"mish": mish,
|
||||
"quick_gelu": quick_gelu,
|
||||
"relu": tf.keras.activations.relu,
|
||||
"sigmoid": tf.keras.activations.sigmoid,
|
||||
"silu": tf.keras.activations.swish,
|
||||
"swish": tf.keras.activations.swish,
|
||||
"tanh": tf.keras.activations.tanh,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -46,18 +46,19 @@ class TestActivations(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(y_gelu * clipped_mask, y_gelu_10 * clipped_mask))
|
||||
|
||||
def test_get_activation(self):
|
||||
get_activation("swish")
|
||||
get_activation("silu")
|
||||
get_activation("relu")
|
||||
get_activation("tanh")
|
||||
get_activation("gelu_new")
|
||||
get_activation("gelu_fast")
|
||||
get_activation("gelu_python")
|
||||
get_activation("gelu")
|
||||
get_activation("gelu_10")
|
||||
get_activation("quick_gelu")
|
||||
get_activation("mish")
|
||||
get_activation("gelu_fast")
|
||||
get_activation("gelu_new")
|
||||
get_activation("gelu_python")
|
||||
get_activation("linear")
|
||||
get_activation("mish")
|
||||
get_activation("quick_gelu")
|
||||
get_activation("relu")
|
||||
get_activation("sigmoid")
|
||||
get_activation("silu")
|
||||
get_activation("swish")
|
||||
get_activation("tanh")
|
||||
with self.assertRaises(KeyError):
|
||||
get_activation("bogus")
|
||||
with self.assertRaises(KeyError):
|
||||
|
||||
@@ -42,17 +42,18 @@ class TestTFActivations(unittest.TestCase):
|
||||
self.assertTrue(np.allclose(y_gelu * clipped_mask, y_gelu_10 * clipped_mask))
|
||||
|
||||
def test_get_activation(self):
|
||||
get_tf_activation("swish")
|
||||
get_tf_activation("silu")
|
||||
get_tf_activation("gelu")
|
||||
get_tf_activation("relu")
|
||||
get_tf_activation("tanh")
|
||||
get_tf_activation("gelu_new")
|
||||
get_tf_activation("gelu_fast")
|
||||
get_tf_activation("gelu_10")
|
||||
get_tf_activation("gelu_fast")
|
||||
get_tf_activation("gelu_new")
|
||||
get_tf_activation("glu")
|
||||
get_tf_activation("mish")
|
||||
get_tf_activation("quick_gelu")
|
||||
get_tf_activation("glu")
|
||||
get_tf_activation("relu")
|
||||
get_tf_activation("sigmoid")
|
||||
get_tf_activation("silu")
|
||||
get_tf_activation("swish")
|
||||
get_tf_activation("tanh")
|
||||
with self.assertRaises(KeyError):
|
||||
get_tf_activation("bogus")
|
||||
with self.assertRaises(KeyError):
|
||||
|
||||
Reference in New Issue
Block a user