Fix importing unofficial TF models with extra optimizer weights
This commit is contained in:
@@ -117,7 +117,13 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
|||||||
name = name.split("/")
|
name = name.split("/")
|
||||||
|
|
||||||
# Ignore the gradients applied by the LAMB/ADAM optimizers.
|
# Ignore the gradients applied by the LAMB/ADAM optimizers.
|
||||||
if "adam_m" in name or "adam_v" in name or "global_step" in name:
|
if (
|
||||||
|
"adam_m" in name
|
||||||
|
or "adam_v" in name
|
||||||
|
or "AdamWeightDecayOptimizer" in name
|
||||||
|
or "AdamWeightDecayOptimizer_1" in name
|
||||||
|
or "global_step" in name
|
||||||
|
):
|
||||||
logger.info("Skipping {}".format("/".join(name)))
|
logger.info("Skipping {}".format("/".join(name)))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||||||
name = name.split("/")
|
name = name.split("/")
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
# which are not required for using pretrained model
|
# which are not required for using pretrained model
|
||||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
if any(
|
||||||
|
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
||||||
|
for n in name
|
||||||
|
):
|
||||||
logger.info("Skipping {}".format("/".join(name)))
|
logger.info("Skipping {}".format("/".join(name)))
|
||||||
continue
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
|
|||||||
@@ -79,7 +79,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
|
|||||||
name = txt_name.split("/")
|
name = txt_name.split("/")
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
# which are not required for using pretrained model
|
# which are not required for using pretrained model
|
||||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
if any(
|
||||||
|
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
||||||
|
for n in name
|
||||||
|
):
|
||||||
logger.info("Skipping {}".format("/".join(name)))
|
logger.info("Skipping {}".format("/".join(name)))
|
||||||
tf_weights.pop(txt_name, None)
|
tf_weights.pop(txt_name, None)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -76,7 +76,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
|
|||||||
name = name.split("/")
|
name = name.split("/")
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
# which are not required for using pretrained model
|
# which are not required for using pretrained model
|
||||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
if any(
|
||||||
|
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
||||||
|
for n in name
|
||||||
|
):
|
||||||
logger.info("Skipping {}".format("/".join(name)))
|
logger.info("Skipping {}".format("/".join(name)))
|
||||||
continue
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
|
|||||||
Reference in New Issue
Block a user