Implementation of activations as pytorch modules (#15616)

* Implement activations as pytorch modules

* Apply fixup

* Add missing tests for activations

* Update docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Eldar Kurtic
2022-02-16 20:37:52 +01:00
committed by GitHub
parent 66828a19b1
commit f65fe3663a
2 changed files with 108 additions and 47 deletions

View File

@@ -16,7 +16,7 @@ import math
import torch import torch
from packaging import version from packaging import version
from torch import nn from torch import Tensor, nn
from .utils import logging from .utils import logging
@@ -24,39 +24,66 @@ from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def gelu_python(x): class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self):
super().__init__()
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
class GELUActivation(nn.Module):
""" """
Original Implementation of the GELU activation function in Google BERT repo when initially created. For Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
""" """
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def __init__(self, use_gelu_python: bool = False):
def gelu_new(x): super().__init__()
""" if version.parse(torch.__version__) < version.parse("1.4") or use_gelu_python:
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see self.act = self._gelu_python
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
if version.parse(torch.__version__) < version.parse("1.4"):
gelu = gelu_python
else: else:
gelu = nn.functional.gelu self.act = nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
def gelu_fast(x): class FastGELUActivation(nn.Module):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) """
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
"""
def __init__(self):
super().__init__()
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
def quick_gelu(x): class QuickGELUActivation(nn.Module):
return x * torch.sigmoid(1.702 * x) """
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
def __init__(self):
super().__init__()
def forward(self, input: Tensor) -> Tensor:
return input * torch.sigmoid(1.702 * input)
def _silu_python(x): class SiLUActivation(nn.Module):
""" """
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
@@ -64,46 +91,65 @@ def _silu_python(x):
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later. later.
""" """
return x * torch.sigmoid(x)
def __init__(self):
if version.parse(torch.__version__) < version.parse("1.7"): if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python self.act = self._silu_python
else: else:
silu = nn.functional.silu self.act = nn.functional.silu
def _silu_python(self, input: Tensor) -> Tensor:
return input * torch.sigmoid(input)
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
def _mish_python(x): class MishActivation(nn.Module):
""" """
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish visit the official repository for the paper: https://github.com/digantamisra98/Mish
""" """
return x * torch.tanh(nn.functional.softplus(x))
def __init__(self):
super().__init__()
if version.parse(torch.__version__) < version.parse("1.9"): if version.parse(torch.__version__) < version.parse("1.9"):
mish = _mish_python self.act = self._mish_python
else: else:
mish = nn.functional.mish self.act = nn.functional.mish
def _mish_python(self, input: Tensor) -> Tensor:
return input * torch.tanh(nn.functional.softplus(input))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
def linear_act(x): class LinearActivation(nn.Module):
return x """
Applies the linear activation function, i.e. forwarding input directly to output.
"""
def __init__(self):
super().__init__()
def forward(self, input: Tensor) -> Tensor:
return input
ACT2FN = { ACT2FN = {
"relu": nn.functional.relu, "relu": nn.ReLU(),
"silu": silu, "silu": SiLUActivation(),
"swish": silu, "swish": SiLUActivation(),
"gelu": gelu, "gelu": GELUActivation(),
"tanh": torch.tanh, "tanh": nn.Tanh(),
"gelu_python": gelu_python, "gelu_python": GELUActivation(use_gelu_python=True),
"gelu_new": gelu_new, "gelu_new": NewGELUActivation(),
"gelu_fast": gelu_fast, "gelu_fast": FastGELUActivation(),
"quick_gelu": quick_gelu, "quick_gelu": QuickGELUActivation(),
"mish": mish, "mish": MishActivation(),
"linear": linear_act, "linear": LinearActivation(),
"sigmoid": torch.sigmoid, "sigmoid": nn.Sigmoid(),
} }
@@ -112,3 +158,14 @@ def get_activation(activation_string):
return ACT2FN[activation_string] return ACT2FN[activation_string]
else: else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")

View File

@@ -40,6 +40,10 @@ class TestActivations(unittest.TestCase):
get_activation("gelu_new") get_activation("gelu_new")
get_activation("gelu_fast") get_activation("gelu_fast")
get_activation("gelu_python") get_activation("gelu_python")
get_activation("quick_gelu")
get_activation("mish")
get_activation("linear")
get_activation("sigmoid")
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
get_activation("bogus") get_activation("bogus")
with self.assertRaises(KeyError): with self.assertRaises(KeyError):