updating
This commit is contained in:
@@ -119,10 +119,10 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
||||
tf_inputs = tf.constant(inputs_list)
|
||||
tfo = tf_model(tf_inputs, training=False) # build the network
|
||||
|
||||
pt_model = pt_model_class(config)
|
||||
pt_model.load_state_dict(torch.load(pytorch_checkpoint_path, map_location='cpu'),
|
||||
strict-False)
|
||||
pt_model.eval()
|
||||
state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu')
|
||||
pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None,
|
||||
config=config,
|
||||
state_dict=state_dict)
|
||||
|
||||
pt_inputs = torch.tensor(inputs_list)
|
||||
with torch.no_grad():
|
||||
@@ -140,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
|
||||
|
||||
|
||||
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None,
|
||||
compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False):
|
||||
compare_with_pt_model=False, use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False):
|
||||
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
|
||||
|
||||
if args_model_type is None:
|
||||
@@ -188,13 +188,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
||||
|
||||
if os.path.isfile(model_shortcut_name):
|
||||
model_shortcut_name = 'converted_model'
|
||||
|
||||
convert_pt_checkpoint_to_tf(model_type=model_type,
|
||||
pytorch_checkpoint_path=model_file,
|
||||
config_file=config_file,
|
||||
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
||||
compare_with_pt_model=compare_with_pt_model)
|
||||
os.remove(config_file)
|
||||
os.remove(model_file)
|
||||
if remove_cached_files:
|
||||
os.remove(config_file)
|
||||
os.remove(model_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -227,6 +229,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--use_cached_models",
|
||||
action='store_true',
|
||||
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
|
||||
parser.add_argument("--remove_cached_files",
|
||||
action='store_true',
|
||||
help = "Remove pytorch models after conversion (save memory when converting in batches).")
|
||||
parser.add_argument("--only_convert_finetuned_models",
|
||||
action='store_true',
|
||||
help = "Only convert finetuned models.")
|
||||
@@ -246,4 +251,5 @@ if __name__ == "__main__":
|
||||
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
||||
compare_with_pt_model=args.compare_with_pt_model,
|
||||
use_cached_models=args.use_cached_models,
|
||||
remove_cached_files=args.remove_cached_files,
|
||||
only_convert_finetuned_models=args.only_convert_finetuned_models)
|
||||
|
||||
Reference in New Issue
Block a user