From f3558bbcfdfff25abe0137d00f8b4f88fb58eed3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xa9aX=20=E3=83=84?=
Date: Fri, 18 Jun 2021 18:43:45 +0530
Subject: [PATCH] Deprecate pythonic Mish and support PyTorch 1.9 version of Mish (#12240)

* Moved Mish to Torch 1.9 version
* Run black formatting
---
 src/transformers/activations.py                      | 12 +++++++++++-
 .../models/mobilebert/modeling_mobilebert.py         |  4 ----
 .../modeling_{{cookiecutter.lowercase_modelname}}.py |  4 ----
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/transformers/activations.py b/src/transformers/activations.py
index 6ef4434808..30301613ae 100644
--- a/src/transformers/activations.py
+++ b/src/transformers/activations.py
@@ -73,10 +73,20 @@ else:
     silu = nn.functional.silu
 
 
-def mish(x):
+def _mish_python(x):
+    """
+    See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
+    visit the official repository for the paper: https://github.com/digantamisra98/Mish
+    """
     return x * torch.tanh(nn.functional.softplus(x))
 
 
+if version.parse(torch.__version__) < version.parse("1.9"):
+    mish = _mish_python
+else:
+    mish = nn.functional.mish
+
+
 def linear_act(x):
     return x
 
diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py
index 3a855ba4fb..448a894beb 100644
--- a/src/transformers/models/mobilebert/modeling_mobilebert.py
+++ b/src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -140,10 +140,6 @@ def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
     return model
 
 
-def mish(x):
-    return x * torch.tanh(nn.functional.softplus(x))
-
-
 class NoNorm(nn.Module):
     def __init__(self, feat_size, eps=None):
         super().__init__()
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index c4e6278459..87a95e6b3b 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -138,10 +138,6 @@ def load_tf_weights_in_{{cookiecutter.lowercase_modelname}}(model, config, tf_ch
     return model
 
 
-def mish(x):
-    return x * torch.tanh(nn.functional.softplus(x))
-
-
 # Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
 class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""