updating
This commit is contained in:
@@ -119,10 +119,10 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|||||||
tf_inputs = tf.constant(inputs_list)
|
tf_inputs = tf.constant(inputs_list)
|
||||||
tfo = tf_model(tf_inputs, training=False) # build the network
|
tfo = tf_model(tf_inputs, training=False) # build the network
|
||||||
|
|
||||||
pt_model = pt_model_class(config)
|
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
|
||||||
pt_model.load_state_dict(torch.load(pytorch_checkpoint_path, map_location='cpu'),
|
pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
|
||||||
strict-False)
|
config=config,
|
||||||
pt_model.eval()
|
state_dict=state_dict)
|
||||||
|
|
||||||
pt_inputs = torch.tensor(inputs_list)
|
pt_inputs = torch.tensor(inputs_list)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -140,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
|||||||
|
|
||||||
|
|
||||||
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
|
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
|
||||||
compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False):
|
compare_with_pt_model=False, use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False):
|
||||||
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
||||||
|
|
||||||
if args_model_type is None:
|
if args_model_type is None:
|
||||||
@@ -188,11 +188,13 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
|||||||
|
|
||||||
if os.path.isfile(model_shortcut_name):
|
if os.path.isfile(model_shortcut_name):
|
||||||
model_shortcut_name = 'converted_model'
|
model_shortcut_name = 'converted_model'
|
||||||
|
|
||||||
convert_pt_checkpoint_to_tf(model_type=model_type,
|
convert_pt_checkpoint_to_tf(model_type=model_type,
|
||||||
pytorch_checkpoint_path=model_file,
|
pytorch_checkpoint_path=model_file,
|
||||||
config_file=config_file,
|
config_file=config_file,
|
||||||
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
||||||
compare_with_pt_model=compare_with_pt_model)
|
compare_with_pt_model=compare_with_pt_model)
|
||||||
|
if remove_cached_files:
|
||||||
os.remove(config_file)
|
os.remove(config_file)
|
||||||
os.remove(model_file)
|
os.remove(model_file)
|
||||||
|
|
||||||
@@ -227,6 +229,9 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--use_cached_models",
|
parser.add_argument("--use_cached_models",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
|
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
|
||||||
|
parser.add_argument("--remove_cached_files",
|
||||||
|
action='store_true',
|
||||||
|
help = "Remove pytorch models after conversion (save memory when converting in batches).")
|
||||||
parser.add_argument("--only_convert_finetuned_models",
|
parser.add_argument("--only_convert_finetuned_models",
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help = "Only convert finetuned models.")
|
help = "Only convert finetuned models.")
|
||||||
@@ -246,4 +251,5 @@ if __name__ == "__main__":
|
|||||||
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
||||||
compare_with_pt_model=args.compare_with_pt_model,
|
compare_with_pt_model=args.compare_with_pt_model,
|
||||||
use_cached_models=args.use_cached_models,
|
use_cached_models=args.use_cached_models,
|
||||||
|
remove_cached_files=args.remove_cached_files,
|
||||||
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
||||||
|
|||||||
Reference in New Issue
Block a user