Fix bad handling of env variable USE_TF / USE_TORCH leading to invalid framework being used.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
committed by
Lysandre Debut
parent
23c6998bf4
commit
6e6c8c52ed
@@ -4,7 +4,6 @@ This file is adapted from the AllenNLP library at https://github.com/allenai/all
|
|||||||
Copyright by the AllenNLP authors.
|
Copyright by the AllenNLP authors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -26,32 +25,31 @@ from tqdm.auto import tqdm
|
|||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.environ.setdefault("USE_TORCH", "YES")
|
if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
|
||||||
if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"):
|
os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
_torch_available = True # pylint: disable=invalid-name
|
_torch_available = True # pylint: disable=invalid-name
|
||||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||||
else:
|
else:
|
||||||
logger.info("USE_TORCH override through env variable, disabling PyTorch")
|
logger.info("Disabling PyTorch because USE_TF is set")
|
||||||
_torch_available = False
|
_torch_available = False
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_torch_available = False # pylint: disable=invalid-name
|
_torch_available = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.environ.setdefault("USE_TF", "YES")
|
if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
|
||||||
if os.environ["USE_TF"].upper() in ("1", "ON", "YES"):
|
os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
|
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
|
||||||
_tf_available = True # pylint: disable=invalid-name
|
_tf_available = True # pylint: disable=invalid-name
|
||||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||||
else:
|
else:
|
||||||
logger.info("USE_TF override through env variable, disabling Tensorflow")
|
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||||
_tf_available = False
|
_tf_available = False
|
||||||
except (ImportError, AssertionError):
|
except (ImportError, AssertionError):
|
||||||
_tf_available = False # pylint: disable=invalid-name
|
_tf_available = False # pylint: disable=invalid-name
|
||||||
@@ -66,7 +64,6 @@ except ImportError:
|
|||||||
)
|
)
|
||||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user