From 1d87b37d100c69ff3b2c1a5dfd271b6cf777176e Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 6 Dec 2019 15:30:09 +0100 Subject: [PATCH] updating --- .../convert_pytorch_checkpoint_to_tf2.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/transformers/convert_pytorch_checkpoint_to_tf2.py b/transformers/convert_pytorch_checkpoint_to_tf2.py index d20eafe2e9..2c419888e8 100644 --- a/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -119,10 +119,10 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file tf_inputs = tf.constant(inputs_list) tfo = tf_model(tf_inputs, training=False) # build the network - pt_model = pt_model_class(config) - pt_model.load_state_dict(torch.load(pytorch_checkpoint_path, map_location='cpu'), - strict-False) - pt_model.eval() + state_dict = torch.load(pytorch_checkpoint_path, map_location='cpu') + pt_model = pt_model_class.from_pretrained(pretrained_model_name_or_path=None, + config=config, + state_dict=state_dict) pt_inputs = torch.tensor(inputs_list) with torch.no_grad(): @@ -140,7 +140,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortcut_names_or_path=None, config_shortcut_names_or_path=None, - compare_with_pt_model=False, use_cached_models=False, only_convert_finetuned_models=False): + compare_with_pt_model=False, use_cached_models=False, remove_cached_files=False, only_convert_finetuned_models=False): assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory" if args_model_type is None: @@ -188,13 +188,15 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, model_shortc if os.path.isfile(model_shortcut_name): model_shortcut_name = 'converted_model' + convert_pt_checkpoint_to_tf(model_type=model_type, pytorch_checkpoint_path=model_file, config_file=config_file, tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + '-tf_model.h5'), compare_with_pt_model=compare_with_pt_model) - os.remove(config_file) - os.remove(model_file) + if remove_cached_files: + os.remove(config_file) + os.remove(model_file) if __name__ == "__main__": @@ -227,6 +229,9 @@ if __name__ == "__main__": parser.add_argument("--use_cached_models", action='store_true', help = "Use cached models if possible instead of updating to latest checkpoint versions.") + parser.add_argument("--remove_cached_files", + action='store_true', + help = "Remove pytorch models after conversion (save memory when converting in batches).") parser.add_argument("--only_convert_finetuned_models", action='store_true', help = "Only convert finetuned models.") @@ -246,4 +251,5 @@ if __name__ == "__main__": config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None, compare_with_pt_model=args.compare_with_pt_model, use_cached_models=args.use_cached_models, + remove_cached_files=args.remove_cached_files, only_convert_finetuned_models=args.only_convert_finetuned_models)