From d486795158bb2eea0c4de96246607eb279e4c9c1 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 16 Apr 2020 11:19:24 -0400 Subject: [PATCH] JIT not compatible with PyTorch/XLA (#3743) --- src/transformers/activations.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 004189556c..759b2fa80f 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -1,9 +1,13 @@ +import logging import math import torch import torch.nn.functional as F +logger = logging.getLogger(__name__) + + def swish(x): return x * torch.sigmoid(x) @@ -29,7 +33,15 @@ if torch.__version__ < "1.4.0": gelu = _gelu_python else: gelu = F.gelu - gelu_new = torch.jit.script(gelu_new) + try: + import torch_xla # noqa F401 + + logger.warning( + "The torch_xla package was detected in the python environment. PyTorch/XLA and JIT is untested," + " no activation function will be traced with JIT." + ) + except ImportError: + gelu_new = torch.jit.script(gelu_new) ACT2FN = { "relu": F.relu,