From 20d07b3a7f8a1f04aa94e3a7f2ec03fad641de70 Mon Sep 17 00:00:00 2001 From: Donatas Repecka Date: Tue, 13 Nov 2018 16:56:25 +0200 Subject: [PATCH] Excluding AdamWeightDecayOptimizer internal variables from restoring --- convert_tf_checkpoint_to_pytorch.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) mode change 100644 => 100755 convert_tf_checkpoint_to_pytorch.py diff --git a/convert_tf_checkpoint_to_pytorch.py b/convert_tf_checkpoint_to_pytorch.py old mode 100644 new mode 100755 index dfcdbee42d..eeebb3728e --- a/convert_tf_checkpoint_to_pytorch.py +++ b/convert_tf_checkpoint_to_pytorch.py @@ -68,11 +68,17 @@ def convert(): arrays.append(array) for name, array in zip(names, arrays): - name = name[5:] # skip "bert/" + if not name.startswith("bert"): + print("Skipping {}".format(name)) + continue + else: + name = name.replace("bert/", "") # skip "bert/" print("Loading {}".format(name)) name = name.split('/') - if name[0] in ['redictions', 'eq_relationship']: - print("Skipping") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m": + print("Skipping {}".format("/".join(name))) continue pointer = model for m_name in name: