test suite independent of framework

This commit is contained in:
thomwolf
2019-09-05 11:18:55 +02:00
parent 9d0a11a68c
commit 518307dfcd
20 changed files with 596 additions and 262 deletions

View File

@@ -43,11 +43,11 @@ from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CO
# Modeling
try:
import torch
torch_available = True # pylint: disable=invalid-name
_torch_available = True # pylint: disable=invalid-name
except ImportError:
torch_available = False # pylint: disable=invalid-name
_torch_available = False # pylint: disable=invalid-name
if torch_available:
if _torch_available:
logger.info("PyTorch version {} available.".format(torch.__version__))
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
@@ -87,19 +87,26 @@ if torch_available:
# TensorFlow
try:
import tensorflow as tf
tf_available = True # pylint: disable=invalid-name
assert int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
except ImportError:
tf_available = False # pylint: disable=invalid-name
_tf_available = False # pylint: disable=invalid-name
if tf_available:
if _tf_available:
logger.info("TensorFlow version {} available.".format(tf.__version__))
from .modeling_tf_utils import TFPreTrainedModel
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining,
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_pt_weights_in_bert)
TFBertForMaskedLM, TFBertForNextSentencePrediction, load_bert_pt_weights_in_tf)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
def is_torch_available():
return _torch_available
def is_tf_available():
return _tf_available