Depreciate pythonic Mish and support PyTorch 1.9 version of Mish (#12240)

* Moved Mish to Torch 1.9 version

* Run black formatting
This commit is contained in:
Xa9aX ツ
2021-06-18 18:43:45 +05:30
committed by GitHub
parent 47a9768334
commit f3558bbcfd
3 changed files with 11 additions and 9 deletions

View File

@@ -73,10 +73,20 @@ else:
silu = nn.functional.silu
def mish(x):
def _mish_python(x):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
return x * torch.tanh(nn.functional.softplus(x))
if version.parse(torch.__version__) < version.parse("1.9"):
mish = _mish_python
else:
mish = nn.functional.mish
def linear_act(x):
return x

View File

@@ -140,10 +140,6 @@ def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
return model
def mish(x):
return x * torch.tanh(nn.functional.softplus(x))
class NoNorm(nn.Module):
def __init__(self, feat_size, eps=None):
super().__init__()

View File

@@ -138,10 +138,6 @@ def load_tf_weights_in_{{cookiecutter.lowercase_modelname}}(model, config, tf_ch
return model
def mish(x):
return x * torch.tanh(nn.functional.softplus(x))
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""