bidirectional conversion TF <=> PT - extended tests
This commit is contained in:
@@ -73,7 +73,8 @@ if _torch_available:
|
||||
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
|
||||
XLMWithLMHeadModel, XLMForSequenceClassification,
|
||||
XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
|
||||
@@ -150,6 +151,15 @@ if _tf_available:
|
||||
load_distilbert_pt_weights_in_tf2,
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
if _tf_available and _torch_available:
|
||||
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
load_pytorch_weights_in_tf2_model,
|
||||
load_pytorch_model_in_tf2_model,
|
||||
load_tf2_checkpoint_in_pytorch_model,
|
||||
load_tf2_weights_in_pytorch_model,
|
||||
load_tf2_model_in_pytorch_model)
|
||||
|
||||
# Files and general utilities
|
||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
cached_path, add_start_docstrings, add_end_docstrings,
|
||||
|
||||
Reference in New Issue
Block a user