diff --git a/convert_tf_checkpoint_to_pytorch.py b/convert_tf_checkpoint_to_pytorch.py index d4d47a3bd6..dfcdbee42d 100755 --- a/convert_tf_checkpoint_to_pytorch.py +++ b/convert_tf_checkpoint_to_pytorch.py @@ -26,14 +26,35 @@ import numpy as np from modeling import BertConfig, BertModel +parser = argparse.ArgumentParser() -def convert(config_path, ckpt_path, out_path=None): +## Required parameters +parser.add_argument("--tf_checkpoint_path", + default = None, + type = str, + required = True, + help = "Path the TensorFlow checkpoint path.") +parser.add_argument("--bert_config_file", + default = None, + type = str, + required = True, + help = "The config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture.") +parser.add_argument("--pytorch_dump_path", + default = None, + type = str, + required = True, + help = "Path to the output PyTorch model.") + +args = parser.parse_args() + +def convert(): # Initialise PyTorch model - config = BertConfig.from_json_file(config_path) + config = BertConfig.from_json_file(args.bert_config_file) model = BertModel(config) # Load weights from TF model - path = ckpt_path + path = args.tf_checkpoint_path print("Converting TensorFlow checkpoint from {}".format(path)) init_vars = tf.train.list_variables(path) @@ -47,17 +68,11 @@ def convert(config_path, ckpt_path, out_path=None): arrays.append(array) for name, array in zip(names, arrays): - if not name.startswith("bert"): - print("Skipping {}".format(name)) - continue - else: - name = name.replace("bert/", "") # skip "bert/" + name = name[5:] # skip "bert/" print("Loading {}".format(name)) name = name.split('/') - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m": - print("Skipping {}".format("/".join(name))) + if name[0] in ['redictions', 'eq_relationship']: + print("Skipping") continue pointer = model for m_name in name: @@ -84,32 +99,7 @@ def convert(config_path, ckpt_path, out_path=None): pointer.data = torch.from_numpy(array) # Save pytorch-model - if out_path is not None: - torch.save(model.state_dict(), out_path) - return model - + torch.save(model.state_dict(), args.pytorch_dump_path) if __name__ == "__main__": - parser = argparse.ArgumentParser() - - ## Required parameters - parser.add_argument("--tf_checkpoint_path", - default=None, - type=str, - required=True, - help="Path the TensorFlow checkpoint path.") - parser.add_argument("--bert_config_file", - default=None, - type=str, - required=True, - help="The config json file corresponding to the pre-trained BERT model. \n" - "This specifies the model architecture.") - parser.add_argument("--pytorch_dump_path", - default=None, - type=str, - required=False, - help="Path to the output PyTorch model.") - - args = parser.parse_args() - print(args) - convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path) + convert()