Fix activations being all the same module (#19728)
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -141,21 +142,29 @@ class LinearActivation(nn.Module):
|
|||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
ACT2FN = {
|
class ClassInstantier(OrderedDict):
|
||||||
"gelu": GELUActivation(),
|
def __getitem__(self, key):
|
||||||
"gelu_10": ClippedGELUActivation(-10, 10),
|
content = super().__getitem__(key)
|
||||||
"gelu_fast": FastGELUActivation(),
|
cls, kwargs = content if isinstance(content, tuple) else (content, {})
|
||||||
"gelu_new": NewGELUActivation(),
|
return cls(**kwargs)
|
||||||
"gelu_python": GELUActivation(use_gelu_python=True),
|
|
||||||
"linear": LinearActivation(),
|
|
||||||
"mish": MishActivation(),
|
ACT2CLS = {
|
||||||
"quick_gelu": QuickGELUActivation(),
|
"gelu": GELUActivation,
|
||||||
"relu": nn.ReLU(),
|
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
|
||||||
"sigmoid": nn.Sigmoid(),
|
"gelu_fast": FastGELUActivation,
|
||||||
"silu": SiLUActivation(),
|
"gelu_new": NewGELUActivation,
|
||||||
"swish": SiLUActivation(),
|
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
|
||||||
"tanh": nn.Tanh(),
|
"linear": LinearActivation,
|
||||||
|
"mish": MishActivation,
|
||||||
|
"quick_gelu": QuickGELUActivation,
|
||||||
|
"relu": nn.ReLU,
|
||||||
|
"sigmoid": nn.Sigmoid,
|
||||||
|
"silu": SiLUActivation,
|
||||||
|
"swish": SiLUActivation,
|
||||||
|
"tanh": nn.Tanh,
|
||||||
}
|
}
|
||||||
|
ACT2FN = ClassInstantier(ACT2CLS)
|
||||||
|
|
||||||
|
|
||||||
def get_activation(activation_string):
|
def get_activation(activation_string):
|
||||||
|
|||||||
@@ -63,3 +63,11 @@ class TestActivations(unittest.TestCase):
|
|||||||
get_activation("bogus")
|
get_activation("bogus")
|
||||||
with self.assertRaises(KeyError):
|
with self.assertRaises(KeyError):
|
||||||
get_activation(None)
|
get_activation(None)
|
||||||
|
|
||||||
|
def test_activations_are_distinct_objects(self):
|
||||||
|
act1 = get_activation("gelu")
|
||||||
|
act1.a = 1
|
||||||
|
act2 = get_activation("gelu")
|
||||||
|
self.assertEqual(act1.a, 1)
|
||||||
|
with self.assertRaises(AttributeError):
|
||||||
|
_ = act2.a
|
||||||
|
|||||||
Reference in New Issue
Block a user