|
|
|
@@ -24,43 +24,43 @@ import tensorflow as tf
|
|
|
|
|
|
|
|
|
|
|
|
from pytorch_transformers import is_torch_available, cached_path
|
|
|
|
from pytorch_transformers import is_torch_available, cached_path
|
|
|
|
|
|
|
|
|
|
|
|
from pytorch_transformers import (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
from pytorch_transformers import (BertConfig, TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
|
|
DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
|
|
|
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
|
|
|
|
|
|
|
|
|
|
|
if is_torch_available():
|
|
|
|
if is_torch_available():
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
from pytorch_transformers import (BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
from pytorch_transformers import (BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
|
|
|
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
(BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
(BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
|
|
DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = (
|
|
|
|
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,) = (
|
|
|
|
|
|
|
|
None, None, None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None,
|
|
|
|
None, None, None,
|
|
|
|
None, None,
|
|
|
|
None, None, None,)
|
|
|
|
None, None,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
@@ -68,22 +68,29 @@ logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
MODEL_CLASSES = {
|
|
|
|
MODEL_CLASSES = {
|
|
|
|
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'bert': (BertConfig, TFBertForPreTraining, load_bert_pt_weights_in_tf2, BertForPreTraining, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
|
|
|
|
'bert-large-uncased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
|
|
|
|
'bert-large-cased-whole-word-masking-finetuned-squad': (BertConfig, TFBertForQuestionAnswering, load_bert_pt_weights_in_tf2, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
|
|
|
|
'bert-base-cased-finetuned-mrpc': (BertConfig, TFBertForSequenceClassification, load_bert_pt_weights_in_tf2, BertForSequenceClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'gpt2': (GPT2Config, TFGPT2LMHeadModel, load_gpt2_pt_weights_in_tf2, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'xlnet': (XLNetConfig, TFXLNetLMHeadModel, load_xlnet_pt_weights_in_tf2, XLNetLMHeadModel, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'xlm': (XLMConfig, TFXLMWithLMHeadModel, load_xlm_pt_weights_in_tf2, XLMWithLMHeadModel, XLM_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'transfo-xl': (TransfoXLConfig, TFTransfoXLLMHeadModel, load_transfo_xl_pt_weights_in_tf2, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'openai-gpt': (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, load_openai_gpt_pt_weights_in_tf2, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'roberta': (RobertaConfig, TFRobertaForMaskedLM, load_roberta_pt_weights_in_tf2, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
|
|
|
|
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, load_roberta_pt_weights_in_tf2, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, load_distilbert_pt_weights_in_tf2, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
|
|
|
|
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, load_distilbert_pt_weights_in_tf2, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False):
|
|
|
|
def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True):
|
|
|
|
if model_type not in MODEL_CLASSES:
|
|
|
|
if model_type not in MODEL_CLASSES:
|
|
|
|
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
|
|
|
|
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))
|
|
|
|
|
|
|
|
|
|
|
|
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
|
|
|
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
|
|
|
|
|
|
|
|
|
|
|
# Initialise TF model
|
|
|
|
# Initialise TF model
|
|
|
|
|
|
|
|
if config_file in aws_config_map:
|
|
|
|
|
|
|
|
config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
|
|
|
|
config = config_class.from_json_file(config_file)
|
|
|
|
config = config_class.from_json_file(config_file)
|
|
|
|
config.output_hidden_states = True
|
|
|
|
config.output_hidden_states = True
|
|
|
|
config.output_attentions = True
|
|
|
|
config.output_attentions = True
|
|
|
|
@@ -91,6 +98,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|
|
|
tf_model = model_class(config)
|
|
|
|
tf_model = model_class(config)
|
|
|
|
|
|
|
|
|
|
|
|
# Load weights from tf checkpoint
|
|
|
|
# Load weights from tf checkpoint
|
|
|
|
|
|
|
|
if pytorch_checkpoint_path in aws_model_maps:
|
|
|
|
|
|
|
|
pytorch_checkpoint_path = cached_path(aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models)
|
|
|
|
tf_model = loading_fct(tf_model, pytorch_checkpoint_path)
|
|
|
|
tf_model = loading_fct(tf_model, pytorch_checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
|
if compare_with_pt_model:
|
|
|
|
if compare_with_pt_model:
|
|
|
|
@@ -117,7 +126,8 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|
|
|
tf_model.save_weights(tf_dump_path, save_format='h5')
|
|
|
|
tf_model.save_weights(tf_dump_path, save_format='h5')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False, use_cached_models=False):
|
|
|
|
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
|
|
|
|
|
|
|
|
compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False):
|
|
|
|
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
|
|
|
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
|
|
|
|
|
|
|
|
|
|
|
if args_model_type is None:
|
|
|
|
if args_model_type is None:
|
|
|
|
@@ -134,20 +144,39 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with
|
|
|
|
|
|
|
|
|
|
|
|
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
|
|
|
config_class, model_class, loading_fct, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
|
|
|
|
|
|
|
|
|
|
|
for i, shortcut_name in enumerate(aws_config_map.keys(), start=1):
|
|
|
|
if model_shortcut_names_or_path is None:
|
|
|
|
|
|
|
|
model_shortcut_names_or_path = list(aws_model_maps.keys())
|
|
|
|
|
|
|
|
if config_shortcut_names_or_path is None:
|
|
|
|
|
|
|
|
config_shortcut_names_or_path = model_shortcut_names_or_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i, (model_shortcut_name, config_shortcut_name) in enumerate(
|
|
|
|
|
|
|
|
zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1):
|
|
|
|
print("-" * 100)
|
|
|
|
print("-" * 100)
|
|
|
|
print(" Converting checkpoint {}/{}: {}".format(i, len(aws_config_map), shortcut_name))
|
|
|
|
if '-squad' in model_shortcut_name or '-mrpc' in model_shortcut_name or '-mnli' in model_shortcut_name:
|
|
|
|
print("-" * 100)
|
|
|
|
if not only_convert_finetuned_models:
|
|
|
|
if 'finetuned' in shortcut_name:
|
|
|
|
print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
|
|
|
|
print(" Skipping finetuned checkpoint ")
|
|
|
|
continue
|
|
|
|
|
|
|
|
model_type = model_shortcut_name
|
|
|
|
|
|
|
|
elif only_convert_finetuned_models:
|
|
|
|
|
|
|
|
print(" Skipping not finetuned checkpoint {}".format(model_shortcut_name))
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
config_file = cached_path(aws_config_map[shortcut_name], force_download=not use_cached_models)
|
|
|
|
print(" Converting checkpoint {}/{}: {} - model_type {}".format(i, len(aws_config_map), model_shortcut_name, model_type))
|
|
|
|
model_file = cached_path(aws_model_maps[shortcut_name], force_download=not use_cached_models)
|
|
|
|
print("-" * 100)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if config_shortcut_name in aws_config_map:
|
|
|
|
|
|
|
|
config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if model_shortcut_name in aws_model_maps:
|
|
|
|
|
|
|
|
model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
|
|
|
|
|
|
|
|
|
|
|
|
convert_pt_checkpoint_to_tf(model_type,
|
|
|
|
convert_pt_checkpoint_to_tf(model_type,
|
|
|
|
model_file,
|
|
|
|
model_file,
|
|
|
|
config_file,
|
|
|
|
config_file,
|
|
|
|
os.path.join(tf_dump_path, shortcut_name + '-tf_model.h5'),
|
|
|
|
os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
|
|
|
compare_with_pt_model=compare_with_pt_model)
|
|
|
|
compare_with_pt_model=compare_with_pt_model)
|
|
|
|
os.remove(config_file)
|
|
|
|
os.remove(config_file)
|
|
|
|
os.remove(model_file)
|
|
|
|
os.remove(model_file)
|
|
|
|
@@ -176,23 +205,29 @@ if __name__ == "__main__":
|
|
|
|
help = "The config json file corresponding to the pre-trained model. \n"
|
|
|
|
help = "The config json file corresponding to the pre-trained model. \n"
|
|
|
|
"This specifies the model architecture. If not given and "
|
|
|
|
"This specifies the model architecture. If not given and "
|
|
|
|
"--pytorch_checkpoint_path is not given or is a shortcut name"
|
|
|
|
"--pytorch_checkpoint_path is not given or is a shortcut name"
|
|
|
|
"use the configuration associated to teh shortcut name on the AWS")
|
|
|
|
"use the configuration associated to the shortcut name on the AWS")
|
|
|
|
parser.add_argument("--compare_with_pt_model",
|
|
|
|
parser.add_argument("--compare_with_pt_model",
|
|
|
|
action='store_true',
|
|
|
|
action='store_true',
|
|
|
|
help = "Compare Tensorflow and PyTorch model predictions.")
|
|
|
|
help = "Compare Tensorflow and PyTorch model predictions.")
|
|
|
|
parser.add_argument("--use_cached_models",
|
|
|
|
parser.add_argument("--use_cached_models",
|
|
|
|
action='store_true',
|
|
|
|
action='store_true',
|
|
|
|
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
|
|
|
|
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
|
|
|
|
|
|
|
|
parser.add_argument("--only_convert_finetuned_models",
|
|
|
|
|
|
|
|
action='store_true',
|
|
|
|
|
|
|
|
help = "Only convert finetuned models.")
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
if args.pytorch_checkpoint_path is not None:
|
|
|
|
# if args.pytorch_checkpoint_path is not None:
|
|
|
|
convert_pt_checkpoint_to_tf(args.model_type.lower(),
|
|
|
|
# convert_pt_checkpoint_to_tf(args.model_type.lower(),
|
|
|
|
args.pytorch_checkpoint_path,
|
|
|
|
# args.pytorch_checkpoint_path,
|
|
|
|
args.config_file,
|
|
|
|
# args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
|
|
|
|
args.tf_dump_path,
|
|
|
|
# args.tf_dump_path,
|
|
|
|
compare_with_pt_model=args.compare_with_pt_model)
|
|
|
|
# compare_with_pt_model=args.compare_with_pt_model,
|
|
|
|
else:
|
|
|
|
# use_cached_models=args.use_cached_models)
|
|
|
|
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
|
|
|
|
# else:
|
|
|
|
args.tf_dump_path,
|
|
|
|
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
|
|
|
|
compare_with_pt_model=args.compare_with_pt_model,
|
|
|
|
args.tf_dump_path,
|
|
|
|
use_cached_models=args.use_cached_models)
|
|
|
|
model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
|
|
|
|
|
|
|
|
compare_with_pt_model=args.compare_with_pt_model,
|
|
|
|
|
|
|
|
use_cached_models=args.use_cached_models,
|
|
|
|
|
|
|
|
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
|
|
|
|