Refactoring the TF activations functions (#7150)

* Refactoring the activations functions into a common file

* Apply style

* remove unused import

* fix tests

* Fix tests.
This commit is contained in:
Julien Plu
2020-09-16 13:03:47 +02:00
committed by GitHub
parent b00cafbde5
commit af8425b749
14 changed files with 126 additions and 211 deletions

View File

@@ -0,0 +1,24 @@
import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_tf
if is_tf_available():
from transformers.activations_tf import get_tf_activation
@require_tf
class TestTFActivations(unittest.TestCase):
def test_get_activation(self):
get_tf_activation("swish")
get_tf_activation("gelu")
get_tf_activation("relu")
get_tf_activation("tanh")
get_tf_activation("gelu_new")
get_tf_activation("gelu_fast")
get_tf_activation("mish")
with self.assertRaises(KeyError):
get_tf_activation("bogus")
with self.assertRaises(KeyError):
get_tf_activation(None)