bidirectional conversion TF <=> PT - extended tests

This commit is contained in:
thomwolf
2019-09-24 13:25:50 +02:00
parent a7e01a248b
commit e9a103c17a
7 changed files with 313 additions and 44 deletions

View File

@@ -73,7 +73,8 @@ if _torch_available:
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
XLMWithLMHeadModel, XLMForSequenceClassification,
XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
@@ -150,6 +151,15 @@ if _tf_available:
load_distilbert_pt_weights_in_tf2,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
if _tf_available and _torch_available:
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
load_pytorch_checkpoint_in_tf2_model,
load_pytorch_weights_in_tf2_model,
load_pytorch_model_in_tf2_model,
load_tf2_checkpoint_in_pytorch_model,
load_tf2_weights_in_pytorch_model,
load_tf2_model_in_pytorch_model)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,