Fix importing unofficial TF models with extra optimizer weights

This commit is contained in:
monologg
2020-01-27 23:39:44 +09:00
committed by Lysandre Debut
parent d7dabfeff5
commit 73368963b2
4 changed files with 19 additions and 4 deletions

View File

@@ -76,7 +76,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info("Skipping {}".format("/".join(name)))
continue
pointer = model