fixing Adam weights skip in TF convert script

This commit is contained in:
thomwolf
2018-12-09 16:17:11 -05:00
parent 91aab2a6d3
commit 13bf0d4659

View File

@@ -50,7 +50,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
name = name.split('/') name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model # which are not required for using pretrained model
if name[-1] in ["adam_v", "adam_m"]: if any(n in ["adam_v", "adam_m"] for n in name):
print("Skipping {}".format("/".join(name))) print("Skipping {}".format("/".join(name)))
continue continue
pointer = model pointer = model