From a6bcfb80156fac34c40a3b8dcd973c4a990d75ca Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 25 Sep 2019 21:14:12 +0200 Subject: [PATCH] fix tests --- pytorch_transformers/__init__.py | 4 ---- pytorch_transformers/file_utils.py | 6 ++++-- pytorch_transformers/tokenization_utils.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 5744537cba..7bcb7cafdf 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -56,8 +56,6 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO # Modeling if is_torch_available(): - logger.info("PyTorch version {} available.".format(torch.__version__)) - from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D) from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelWithLMHead) @@ -96,8 +94,6 @@ if is_torch_available(): # TensorFlow if is_tf_available(): - logger.info("TensorFlow version {} available.".format(tf.__version__)) - from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, TFAutoModelWithLMHead) diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index 2c761ef51d..90bdb231f1 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -23,16 +23,20 @@ from botocore.exceptions import ClientError import requests from tqdm import tqdm +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + try: import tensorflow as tf assert int(tf.__version__[0]) >= 2 _tf_available = True # pylint: disable=invalid-name + logger.info("TensorFlow version {} available.".format(tf.__version__)) except (ImportError, AssertionError): _tf_available = False # pylint: disable=invalid-name try: import torch _torch_available = True # pylint: disable=invalid-name + logger.info("PyTorch version {} available.".format(torch.__version__)) except ImportError: _torch_available = False # pylint: disable=invalid-name @@ -67,8 +71,6 @@ TF2_WEIGHTS_NAME = 'tf_model.h5' TF_WEIGHTS_NAME = 'model.ckpt' CONFIG_NAME = "config.json" -logger = logging.getLogger(__name__) # pylint: disable=invalid-name - def is_torch_available(): return _torch_available diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 74797ea206..9a9b141412 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -27,7 +27,7 @@ from .file_utils import cached_path, is_tf_available, is_torch_available if is_tf_available(): import tensorflow as tf -if is_torch_available() +if is_torch_available(): import torch logger = logging.getLogger(__name__)