update __init__ and conversion script
This commit is contained in:
@@ -113,6 +113,11 @@ if _tf_available:
|
|||||||
load_gpt2_pt_weights_in_tf2,
|
load_gpt2_pt_weights_in_tf2,
|
||||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
|
from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer,
|
||||||
|
TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel,
|
||||||
|
load_openai_gpt_pt_weights_in_tf2,
|
||||||
|
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
|
from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer,
|
||||||
TFTransfoXLModel, TFTransfoXLLMHeadModel,
|
TFTransfoXLModel, TFTransfoXLLMHeadModel,
|
||||||
load_transfo_xl_pt_weights_in_tf2,
|
load_transfo_xl_pt_weights_in_tf2,
|
||||||
@@ -132,6 +137,19 @@ if _tf_available:
|
|||||||
load_xlm_pt_weights_in_tf2,
|
load_xlm_pt_weights_in_tf2,
|
||||||
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
|
from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer,
|
||||||
|
TFRobertaModel, TFRobertaLMHead,
|
||||||
|
TFRobertaForSequenceClassification,
|
||||||
|
load_roberta_pt_weights_in_tf2,
|
||||||
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
|
from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
|
||||||
|
TFDistilBertModel, TFDistilBertForMaskedLM,
|
||||||
|
TFDistilBertForSequenceClassification,
|
||||||
|
TFDistilBertForSequenceClassification,
|
||||||
|
load_distilbert_pt_weights_in_tf2,
|
||||||
|
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
# Files and general utilities
|
# Files and general utilities
|
||||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
cached_path, add_start_docstrings, add_end_docstrings,
|
cached_path, add_start_docstrings, add_end_docstrings,
|
||||||
|
|||||||
@@ -24,31 +24,43 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from pytorch_transformers import is_torch_available, cached_path
|
from pytorch_transformers import is_torch_available, cached_path
|
||||||
|
|
||||||
from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2,
|
from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2,
|
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2,
|
XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2,
|
XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2,)
|
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
RobertaConfig, TFRobertaLMHead, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pytorch_transformers import (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
from pytorch_transformers import (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,)
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
else:
|
else:
|
||||||
(BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
(BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,) = (
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
None, None, None,
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
None, None, None,
|
RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
None, None, None,
|
DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = (
|
||||||
None, None, None,
|
None, None,
|
||||||
None, None, None,)
|
None, None,
|
||||||
|
None, None,
|
||||||
|
None, None,
|
||||||
|
None, None,
|
||||||
|
None, None,
|
||||||
|
None, None,
|
||||||
|
None, None,)
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -60,6 +72,9 @@ MODEL_CLASSES = {
|
|||||||
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
|
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
|
'roberta': (RobertaConfig, TFRobertaLMHead, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
|
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
}
|
}
|
||||||
|
|
||||||
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
|
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
|
||||||
|
|||||||
Reference in New Issue
Block a user