Fix activations being all the same module (#19728)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
@@ -141,21 +142,29 @@ class LinearActivation(nn.Module):
|
||||
return input
|
||||
|
||||
|
||||
ACT2FN = {
|
||||
"gelu": GELUActivation(),
|
||||
"gelu_10": ClippedGELUActivation(-10, 10),
|
||||
"gelu_fast": FastGELUActivation(),
|
||||
"gelu_new": NewGELUActivation(),
|
||||
"gelu_python": GELUActivation(use_gelu_python=True),
|
||||
"linear": LinearActivation(),
|
||||
"mish": MishActivation(),
|
||||
"quick_gelu": QuickGELUActivation(),
|
||||
"relu": nn.ReLU(),
|
||||
"sigmoid": nn.Sigmoid(),
|
||||
"silu": SiLUActivation(),
|
||||
"swish": SiLUActivation(),
|
||||
"tanh": nn.Tanh(),
|
||||
class ClassInstantier(OrderedDict):
|
||||
def __getitem__(self, key):
|
||||
content = super().__getitem__(key)
|
||||
cls, kwargs = content if isinstance(content, tuple) else (content, {})
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
ACT2CLS = {
|
||||
"gelu": GELUActivation,
|
||||
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
|
||||
"gelu_fast": FastGELUActivation,
|
||||
"gelu_new": NewGELUActivation,
|
||||
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
|
||||
"linear": LinearActivation,
|
||||
"mish": MishActivation,
|
||||
"quick_gelu": QuickGELUActivation,
|
||||
"relu": nn.ReLU,
|
||||
"sigmoid": nn.Sigmoid,
|
||||
"silu": SiLUActivation,
|
||||
"swish": SiLUActivation,
|
||||
"tanh": nn.Tanh,
|
||||
}
|
||||
ACT2FN = ClassInstantier(ACT2CLS)
|
||||
|
||||
|
||||
def get_activation(activation_string):
|
||||
|
||||
@@ -63,3 +63,11 @@ class TestActivations(unittest.TestCase):
|
||||
get_activation("bogus")
|
||||
with self.assertRaises(KeyError):
|
||||
get_activation(None)
|
||||
|
||||
def test_activations_are_distinct_objects(self):
|
||||
act1 = get_activation("gelu")
|
||||
act1.a = 1
|
||||
act2 = get_activation("gelu")
|
||||
self.assertEqual(act1.a, 1)
|
||||
with self.assertRaises(AttributeError):
|
||||
_ = act2.a
|
||||
|
||||
Reference in New Issue
Block a user