JIT not compatible with PyTorch/XLA (#3743)

This commit is contained in:
Lysandre Debut
2020-04-16 11:19:24 -04:00
committed by GitHub
parent b1e2368b32
commit d486795158

View File

@@ -1,9 +1,13 @@
import logging
import math import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
logger = logging.getLogger(__name__)
def swish(x): def swish(x):
return x * torch.sigmoid(x) return x * torch.sigmoid(x)
@@ -29,6 +33,14 @@ if torch.__version__ < "1.4.0":
gelu = _gelu_python gelu = _gelu_python
else: else:
gelu = F.gelu gelu = F.gelu
try:
import torch_xla # noqa F401
logger.warning(
"The torch_xla package was detected in the python environment. PyTorch/XLA and JIT is untested,"
" no activation function will be traced with JIT."
)
except ImportError:
gelu_new = torch.jit.script(gelu_new) gelu_new = torch.jit.script(gelu_new)
ACT2FN = { ACT2FN = {