Speed up GELU computation with torch.jit (#2988)
* Compile gelu_new with torchscript
* Compile _gelu_python with torchscript
* Wrap gelu_new with torch.jit for torch>=1.4
This commit is contained in:
@@ -18,12 +18,6 @@ def _gelu_python(x):
|
|||||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||||
|
|
||||||
|
|
||||||
if torch.__version__ < "1.4.0":
|
|
||||||
gelu = _gelu_python
|
|
||||||
else:
|
|
||||||
gelu = F.gelu
|
|
||||||
|
|
||||||
|
|
||||||
def gelu_new(x):
|
def gelu_new(x):
|
||||||
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
||||||
Also see https://arxiv.org/abs/1606.08415
|
Also see https://arxiv.org/abs/1606.08415
|
||||||
@@ -31,6 +25,12 @@ def gelu_new(x):
|
|||||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
|
|
||||||
|
|
||||||
|
if torch.__version__ < "1.4.0":
|
||||||
|
gelu = _gelu_python
|
||||||
|
else:
|
||||||
|
gelu = F.gelu
|
||||||
|
gelu_new = torch.jit.script(gelu_new)
|
||||||
|
|
||||||
ACT2FN = {
|
ACT2FN = {
|
||||||
"relu": F.relu,
|
"relu": F.relu,
|
||||||
"swish": swish,
|
"swish": swish,
|
||||||
|
|||||||
Reference in New Issue
Block a user