fixing Adam weights skip in TF convert script
This commit is contained in:
@@ -50,7 +50,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
|||||||
name = name.split('/')
|
name = name.split('/')
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
# which are not required for using pretrained model
|
# which are not required for using pretrained model
|
||||||
if name[-1] in ["adam_v", "adam_m"]:
|
if any(n in ["adam_v", "adam_m"] for n in name):
|
||||||
print("Skipping {}".format("/".join(name)))
|
print("Skipping {}".format("/".join(name)))
|
||||||
continue
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
|
|||||||
Reference in New Issue
Block a user