clean up pr
This commit is contained in:
@@ -26,14 +26,35 @@ import numpy as np
|
|||||||
|
|
||||||
from modeling import BertConfig, BertModel
|
from modeling import BertConfig, BertModel
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
def convert(config_path, ckpt_path, out_path=None):
|
## Required parameters
|
||||||
|
parser.add_argument("--tf_checkpoint_path",
|
||||||
|
default = None,
|
||||||
|
type = str,
|
||||||
|
required = True,
|
||||||
|
help = "Path the TensorFlow checkpoint path.")
|
||||||
|
parser.add_argument("--bert_config_file",
|
||||||
|
default = None,
|
||||||
|
type = str,
|
||||||
|
required = True,
|
||||||
|
help = "The config json file corresponding to the pre-trained BERT model. \n"
|
||||||
|
"This specifies the model architecture.")
|
||||||
|
parser.add_argument("--pytorch_dump_path",
|
||||||
|
default = None,
|
||||||
|
type = str,
|
||||||
|
required = True,
|
||||||
|
help = "Path to the output PyTorch model.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
def convert():
|
||||||
# Initialise PyTorch model
|
# Initialise PyTorch model
|
||||||
config = BertConfig.from_json_file(config_path)
|
config = BertConfig.from_json_file(args.bert_config_file)
|
||||||
model = BertModel(config)
|
model = BertModel(config)
|
||||||
|
|
||||||
# Load weights from TF model
|
# Load weights from TF model
|
||||||
path = ckpt_path
|
path = args.tf_checkpoint_path
|
||||||
print("Converting TensorFlow checkpoint from {}".format(path))
|
print("Converting TensorFlow checkpoint from {}".format(path))
|
||||||
|
|
||||||
init_vars = tf.train.list_variables(path)
|
init_vars = tf.train.list_variables(path)
|
||||||
@@ -47,17 +68,11 @@ def convert(config_path, ckpt_path, out_path=None):
|
|||||||
arrays.append(array)
|
arrays.append(array)
|
||||||
|
|
||||||
for name, array in zip(names, arrays):
|
for name, array in zip(names, arrays):
|
||||||
if not name.startswith("bert"):
|
name = name[5:] # skip "bert/"
|
||||||
print("Skipping {}".format(name))
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
name = name.replace("bert/", "") # skip "bert/"
|
|
||||||
print("Loading {}".format(name))
|
print("Loading {}".format(name))
|
||||||
name = name.split('/')
|
name = name.split('/')
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
if name[0] in ['redictions', 'eq_relationship']:
|
||||||
# which are not required for using pretrained model
|
print("Skipping")
|
||||||
if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
|
|
||||||
print("Skipping {}".format("/".join(name)))
|
|
||||||
continue
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
for m_name in name:
|
for m_name in name:
|
||||||
@@ -84,32 +99,7 @@ def convert(config_path, ckpt_path, out_path=None):
|
|||||||
pointer.data = torch.from_numpy(array)
|
pointer.data = torch.from_numpy(array)
|
||||||
|
|
||||||
# Save pytorch-model
|
# Save pytorch-model
|
||||||
if out_path is not None:
|
torch.save(model.state_dict(), args.pytorch_dump_path)
|
||||||
torch.save(model.state_dict(), out_path)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
convert()
|
||||||
|
|
||||||
## Required parameters
|
|
||||||
parser.add_argument("--tf_checkpoint_path",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path the TensorFlow checkpoint path.")
|
|
||||||
parser.add_argument("--bert_config_file",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="The config json file corresponding to the pre-trained BERT model. \n"
|
|
||||||
"This specifies the model architecture.")
|
|
||||||
parser.add_argument("--pytorch_dump_path",
|
|
||||||
default=None,
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
help="Path to the output PyTorch model.")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
print(args)
|
|
||||||
convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user