add GPT2 to init - fix weights loading - remove tf.function

This commit is contained in:
thomwolf
2019-09-09 14:29:24 +02:00
parent 78b2a53f10
commit 33cb00f41a
4 changed files with 79 additions and 58 deletions

View File

@@ -99,13 +99,19 @@ if _tf_available:
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
TFAutoModelWithLMHead)
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining,
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertMainLayer, TFBertEmbeddings,
TFBertModel, TFBertForPreTraining,
TFBertForMaskedLM, TFBertForNextSentencePrediction,
TFBertForSequenceClassification, TFBertForMultipleChoice,
TFBertForTokenClassification, TFBertForQuestionAnswering,
load_bert_pt_weights_in_tf,
load_bert_pt_weights_in_tf2,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Embeddings,
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
load_gpt2_pt_weights_in_tf2,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
# Files and general utilities
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,