get_activation('relu') provides a simple mapping from strings i… (#2807)

* activations.py contains a mapping from string to activation function
* resolves some `gelu` vs `gelu_new` ambiguity
This commit is contained in:
Sam Shleifer
2020-02-13 08:28:33 -05:00
committed by GitHub
parent f54a5bd37f
commit ef74b0f07a
9 changed files with 94 additions and 68 deletions

28
tests/test_activations.py Normal file
View File

@@ -0,0 +1,28 @@
import unittest
from transformers import is_torch_available
from .utils import require_torch
if is_torch_available():
from transformers.activations import _gelu_python, get_activation, gelu_new
import torch
@require_torch
class TestActivations(unittest.TestCase):
def test_gelu_versions(self):
x = torch.Tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
torch_builtin = get_activation("gelu")
self.assertTrue(torch.eq(_gelu_python(x), torch_builtin(x)).all().item())
self.assertFalse(torch.eq(_gelu_python(x), gelu_new(x)).all().item())
def test_get_activation(self):
get_activation("swish")
get_activation("relu")
get_activation("tanh")
with self.assertRaises(KeyError):
get_activation("bogus")
with self.assertRaises(KeyError):
get_activation(None)