JIT not compatible with PyTorch/XLA (#3743)
This commit is contained in:
@@ -1,9 +1,13 @@
|
|||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def swish(x):
|
def swish(x):
|
||||||
return x * torch.sigmoid(x)
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
@@ -29,7 +33,15 @@ if torch.__version__ < "1.4.0":
|
|||||||
gelu = _gelu_python
|
gelu = _gelu_python
|
||||||
else:
|
else:
|
||||||
gelu = F.gelu
|
gelu = F.gelu
|
||||||
gelu_new = torch.jit.script(gelu_new)
|
try:
|
||||||
|
import torch_xla # noqa F401
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"The torch_xla package was detected in the python environment. PyTorch/XLA and JIT is untested,"
|
||||||
|
" no activation function will be traced with JIT."
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
gelu_new = torch.jit.script(gelu_new)
|
||||||
|
|
||||||
ACT2FN = {
|
ACT2FN = {
|
||||||
"relu": F.relu,
|
"relu": F.relu,
|
||||||
|
|||||||
Reference in New Issue
Block a user