From 6e6c8c52ed8a60063d4ec1ed7f7eb79d7da126ef Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Fri, 10 Jan 2020 14:20:10 +0100 Subject: [PATCH] Fix bad handling of env variable USE_TF / USE_TORCH leading to invalid framework being used. Signed-off-by: Morgan Funtowicz --- src/transformers/file_utils.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 3b90dca7c2..a0489a4e06 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -4,7 +4,6 @@ This file is adapted from the AllenNLP library at https://github.com/allenai/all Copyright by the AllenNLP authors. """ - import fnmatch import json import logging @@ -26,32 +25,31 @@ from tqdm.auto import tqdm from . import __version__ - logger = logging.getLogger(__name__) # pylint: disable=invalid-name try: - os.environ.setdefault("USE_TORCH", "YES") - if os.environ["USE_TORCH"].upper() in ("1", "ON", "YES"): + if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \ + os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"): import torch _torch_available = True # pylint: disable=invalid-name logger.info("PyTorch version {} available.".format(torch.__version__)) else: - logger.info("USE_TORCH override through env variable, disabling PyTorch") + logger.info("Disabling PyTorch because USE_TF is set") _torch_available = False except ImportError: _torch_available = False # pylint: disable=invalid-name try: - os.environ.setdefault("USE_TF", "YES") - if os.environ["USE_TF"].upper() in ("1", "ON", "YES"): + if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \ + os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"): import tensorflow as tf assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 _tf_available = True # pylint: disable=invalid-name logger.info("TensorFlow version {} available.".format(tf.__version__)) else: - logger.info("USE_TF override through env variable, disabling Tensorflow") + logger.info("Disabling Tensorflow because USE_TORCH is set") _tf_available = False except (ImportError, AssertionError): _tf_available = False # pylint: disable=invalid-name @@ -66,7 +64,6 @@ except ImportError: ) default_cache_path = os.path.join(torch_cache_home, "transformers") - try: from pathlib import Path