WIP XLNet

This commit is contained in:
thomwolf
2019-09-10 12:17:18 +02:00
parent f851fb55ca
commit 32aabe8c33
7 changed files with 1540 additions and 68 deletions

View File

@@ -95,7 +95,7 @@ except (ImportError, AssertionError):
if _tf_available:
logger.info("TensorFlow version {} available.".format(tf.__version__))
from .modeling_tf_utils import TFPreTrainedModel
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
TFAutoModelWithLMHead)
@@ -107,7 +107,7 @@ if _tf_available:
load_bert_pt_weights_in_tf2,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Embeddings,
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
load_gpt2_pt_weights_in_tf2,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)