test suite independent of framework
This commit is contained in:
@@ -43,11 +43,11 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO
|
||||
# Modeling
|
||||
try:
|
||||
import torch
|
||||
torch_available = True # pylint: disable=invalid-name
|
||||
_torch_available = True # pylint: disable=invalid-name
|
||||
except ImportError:
|
||||
torch_available = False # pylint: disable=invalid-name
|
||||
_torch_available = False # pylint: disable=invalid-name
|
||||
|
||||
if torch_available:
|
||||
if _torch_available:
|
||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||
|
||||
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
||||
@@ -87,19 +87,26 @@ if torch_available:
|
||||
# TensorFlow
|
||||
try:
|
||||
import tensorflow as tf
|
||||
tf_available = True # pylint: disable=invalid-name
|
||||
assert int(tf.__version__[0]) >= 2
|
||||
_tf_available = True # pylint: disable=invalid-name
|
||||
except ImportError:
|
||||
tf_available = False # pylint: disable=invalid-name
|
||||
_tf_available = False # pylint: disable=invalid-name
|
||||
|
||||
if tf_available:
|
||||
if _tf_available:
|
||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||
|
||||
from .modeling_tf_utils import TFPreTrainedModel
|
||||
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining,
|
||||
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_pt_weights_in_bert)
|
||||
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_bert_pt_weights_in_tf)
|
||||
|
||||
|
||||
# Files and general utilities
|
||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
cached_path, add_start_docstrings, add_end_docstrings,
|
||||
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
Reference in New Issue
Block a user