From fb0bd7b7a8ca7a7b7034c0c7b9c2b7ff116aed8d Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Tue, 18 Oct 2022 11:56:45 -0400
Subject: [PATCH] Fix activations being all the same module (#19728)

---
 src/transformers/activations.py | 37 ++++++++++++++++++++-------------
 tests/utils/test_activations.py |  8 +++++++
 2 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/src/transformers/activations.py b/src/transformers/activations.py
index f7c9046134..aaf4d81625 100644
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+from collections import OrderedDict
 
 import torch
 from packaging import version
@@ -141,21 +142,29 @@ class LinearActivation(nn.Module):
         return input
 
 
-ACT2FN = {
-    "gelu": GELUActivation(),
-    "gelu_10": ClippedGELUActivation(-10, 10),
-    "gelu_fast": FastGELUActivation(),
-    "gelu_new": NewGELUActivation(),
-    "gelu_python": GELUActivation(use_gelu_python=True),
-    "linear": LinearActivation(),
-    "mish": MishActivation(),
-    "quick_gelu": QuickGELUActivation(),
-    "relu": nn.ReLU(),
-    "sigmoid": nn.Sigmoid(),
-    "silu": SiLUActivation(),
-    "swish": SiLUActivation(),
-    "tanh": nn.Tanh(),
+class ClassInstantier(OrderedDict):
+    def __getitem__(self, key):
+        content = super().__getitem__(key)
+        cls, kwargs = content if isinstance(content, tuple) else (content, {})
+        return cls(**kwargs)
+
+
+ACT2CLS = {
+    "gelu": GELUActivation,
+    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
+    "gelu_fast": FastGELUActivation,
+    "gelu_new": NewGELUActivation,
+    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
+    "linear": LinearActivation,
+    "mish": MishActivation,
+    "quick_gelu": QuickGELUActivation,
+    "relu": nn.ReLU,
+    "sigmoid": nn.Sigmoid,
+    "silu": SiLUActivation,
+    "swish": SiLUActivation,
+    "tanh": nn.Tanh,
 }
+ACT2FN = ClassInstantier(ACT2CLS)
 
 
 def get_activation(activation_string):
diff --git a/tests/utils/test_activations.py b/tests/utils/test_activations.py
index 29e487ee97..1e301f948a 100644
--- a/tests/utils/test_activations.py
+++ b/tests/utils/test_activations.py
@@ -63,3 +63,11 @@ class TestActivations(unittest.TestCase):
             get_activation("bogus")
         with self.assertRaises(KeyError):
             get_activation(None)
+
+    def test_activations_are_distinct_objects(self):
+        act1 = get_activation("gelu")
+        act1.a = 1
+        act2 = get_activation("gelu")
+        self.assertEqual(act1.a, 1)
+        with self.assertRaises(AttributeError):
+            _ = act2.a