Depreciate pythonic Mish and support PyTorch 1.9 version of Mish (#12240)
* Moved Mish to Torch 1.9 version * Run black formatting
This commit is contained in:
@@ -73,10 +73,20 @@ else:
|
|||||||
silu = nn.functional.silu
|
silu = nn.functional.silu
|
||||||
|
|
||||||
|
|
||||||
def mish(x):
|
def _mish_python(x):
|
||||||
|
"""
|
||||||
|
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
|
||||||
|
visit the official repository for the paper: https://github.com/digantamisra98/Mish
|
||||||
|
"""
|
||||||
return x * torch.tanh(nn.functional.softplus(x))
|
return x * torch.tanh(nn.functional.softplus(x))
|
||||||
|
|
||||||
|
|
||||||
|
if version.parse(torch.__version__) < version.parse("1.9"):
|
||||||
|
mish = _mish_python
|
||||||
|
else:
|
||||||
|
mish = nn.functional.mish
|
||||||
|
|
||||||
|
|
||||||
def linear_act(x):
|
def linear_act(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
@@ -140,10 +140,6 @@ def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def mish(x):
|
|
||||||
return x * torch.tanh(nn.functional.softplus(x))
|
|
||||||
|
|
||||||
|
|
||||||
class NoNorm(nn.Module):
|
class NoNorm(nn.Module):
|
||||||
def __init__(self, feat_size, eps=None):
|
def __init__(self, feat_size, eps=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -138,10 +138,6 @@ def load_tf_weights_in_{{cookiecutter.lowercase_modelname}}(model, config, tf_ch
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def mish(x):
|
|
||||||
return x * torch.tanh(nn.functional.softplus(x))
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
|
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
|
||||||
class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
||||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||||
|
|||||||
Reference in New Issue
Block a user