From c6acd246ec90857b70f449dcbcb1543f150821fc Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Fri, 3 Apr 2020 22:20:21 +0300
Subject: [PATCH] Speed up GELU computation with torch.jit (#2988)

* Compile gelu_new with torchscript

* Compile _gelu_python with torchscript

* Wrap gelu_new with torch.jit for torch>=1.4
---
 src/transformers/activations.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/transformers/activations.py b/src/transformers/activations.py
index 7968b88ba9..004189556c 100644
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -18,12 +18,6 @@ def _gelu_python(x):
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
-if torch.__version__ < "1.4.0":
-    gelu = _gelu_python
-else:
-    gelu = F.gelu
-
-
 def gelu_new(x):
     """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
         Also see https://arxiv.org/abs/1606.08415
@@ -31,6 +25,12 @@ def gelu_new(x):
     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
 
 
+if torch.__version__ < "1.4.0":
+    gelu = _gelu_python
+else:
+    gelu = F.gelu
+    gelu_new = torch.jit.script(gelu_new)
+
 ACT2FN = {
     "relu": F.relu,
     "swish": swish,