fix convert pt_to_tf2 for custom weights
This commit is contained in:
@@ -173,10 +173,12 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc
|
|||||||
else:
|
else:
|
||||||
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
|
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
|
||||||
|
|
||||||
convert_pt_checkpoint_to_tf(model_type,
|
if os.path.isfile(model_shortcut_name):
|
||||||
model_file,
|
model_shortcut_name = 'converted_model'
|
||||||
config_file,
|
convert_pt_checkpoint_to_tf(model_type=model_type,
|
||||||
os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
pytorch_checkpoint_path=model_file,
|
||||||
|
config_file=config_file,
|
||||||
|
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'),
|
||||||
compare_with_pt_model=compare_with_pt_model)
|
compare_with_pt_model=compare_with_pt_model)
|
||||||
os.remove(config_file)
|
os.remove(config_file)
|
||||||
os.remove(model_file)
|
os.remove(model_file)
|
||||||
|
|||||||
Reference in New Issue
Block a user