clean up __init__
This commit is contained in:
@@ -16,7 +16,21 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Tokenizer
|
||||
# Files and general utilities
|
||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
cached_path, add_start_docstrings, add_end_docstrings,
|
||||
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
|
||||
is_tf_available, is_torch_available)
|
||||
|
||||
from .data import (is_sklearn_available,
|
||||
InputExample, InputFeatures, DataProcessor,
|
||||
glue_output_modes, glue_convert_examples_to_features,
|
||||
glue_processors, glue_tasks_num_labels)
|
||||
|
||||
if is_sklearn_available():
|
||||
from .data import glue_compute_metrics
|
||||
|
||||
# Tokenizers
|
||||
from .tokenization_utils import (PreTrainedTokenizer)
|
||||
from .tokenization_auto import AutoTokenizer
|
||||
from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||
@@ -41,13 +55,7 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH
|
||||
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
# Modeling
|
||||
try:
|
||||
import torch
|
||||
_torch_available = True # pylint: disable=invalid-name
|
||||
except ImportError:
|
||||
_torch_available = False # pylint: disable=invalid-name
|
||||
|
||||
if _torch_available:
|
||||
if is_torch_available():
|
||||
logger.info("PyTorch version {} available.".format(torch.__version__))
|
||||
|
||||
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
||||
@@ -87,14 +95,7 @@ if _torch_available:
|
||||
|
||||
|
||||
# TensorFlow
|
||||
try:
|
||||
import tensorflow as tf
|
||||
assert int(tf.__version__[0]) >= 2
|
||||
_tf_available = True # pylint: disable=invalid-name
|
||||
except (ImportError, AssertionError):
|
||||
_tf_available = False # pylint: disable=invalid-name
|
||||
|
||||
if _tf_available:
|
||||
if is_tf_available():
|
||||
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
||||
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
|
||||
@@ -151,7 +152,8 @@ if _tf_available:
|
||||
load_distilbert_pt_weights_in_tf2,
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
if _tf_available and _torch_available:
|
||||
# TF 2.0 <=> PyTorch conversion utilities
|
||||
if is_tf_available() and is_torch_available():
|
||||
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
load_pytorch_weights_in_tf2_model,
|
||||
@@ -159,17 +161,3 @@ if _tf_available and _torch_available:
|
||||
load_tf2_checkpoint_in_pytorch_model,
|
||||
load_tf2_weights_in_pytorch_model,
|
||||
load_tf2_model_in_pytorch_model)
|
||||
|
||||
# Files and general utilities
|
||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
cached_path, add_start_docstrings, add_end_docstrings,
|
||||
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
|
||||
is_tf_available, is_torch_available)
|
||||
|
||||
from .data import (is_sklearn_available,
|
||||
InputExample, InputFeatures, DataProcessor,
|
||||
glue_output_modes, glue_convert_examples_to_features,
|
||||
glue_processors, glue_tasks_num_labels)
|
||||
|
||||
if is_sklearn_available():
|
||||
from .data import glue_compute_metrics
|
||||
|
||||
Reference in New Issue
Block a user