From 9f81f1cba8a5f6ffc3c449909489343555745df5 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Mon, 7 Oct 2019 12:30:19 -0400 Subject: [PATCH] fix convert pt_to_tf2 for custom weights --- transformers/convert_pytorch_checkpoint_to_tf2.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformers/convert_pytorch_checkpoint_to_tf2.py b/transformers/convert_pytorch_checkpoint_to_tf2.py index d8a48e9dcd..b7e0e79183 100644 --- a/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -173,10 +173,12 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc else: model_file = cached_path(model_shortcut_name, force_download=not use_cached_models) - convert_pt_checkpoint_to_tf(model_type, - model_file, - config_file, - os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'), + if os.path.isfile(model_shortcut_name): + model_shortcut_name = 'converted_model' + convert_pt_checkpoint_to_tf(model_type=model_type, + pytorch_checkpoint_path=model_file, + config_file=config_file, + tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'), compare_with_pt_model=compare_with_pt_model) os.remove(config_file) os.remove(model_file)