Fixes to the TensorFlow conversion tool

This commit is contained in:
Mike Arpaia
2019-04-01 12:53:51 -06:00
parent ec5c1d6134
commit 8b5c63e4de
2 changed files with 7 additions and 3 deletions

View File

@@ -76,7 +76,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using a pretrained model
if any(n in ["adam_v", "adam_m"] for n in name):
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model
@@ -92,7 +92,11 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
try:
pointer = getattr(pointer, l[0])
except AttributeError:
print("Skipping {}".format("/".join(name)))
continue
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]