Fix gelu test for torch 1.10 (#14167)

This commit is contained in:
Lysandre Debut
2021-10-26 22:20:51 -04:00
committed by GitHub
parent 8ddbfe9752
commit 1e53faeb2e

View File

@@ -29,8 +29,8 @@ class TestActivations(unittest.TestCase):
def test_gelu_versions(self): def test_gelu_versions(self):
x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100]) x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
torch_builtin = get_activation("gelu") torch_builtin = get_activation("gelu")
self.assertTrue(torch.eq(_gelu_python(x), torch_builtin(x)).all().item()) self.assertTrue(torch.allclose(_gelu_python(x), torch_builtin(x)))
self.assertFalse(torch.eq(_gelu_python(x), gelu_new(x)).all().item()) self.assertFalse(torch.allclose(_gelu_python(x), gelu_new(x)))
def test_get_activation(self): def test_get_activation(self):
get_activation("swish") get_activation("swish")