Excluding AdamWeightDecayOptimizer internal variables from restoring
This commit is contained in:
12
convert_tf_checkpoint_to_pytorch.py
Normal file → Executable file
12
convert_tf_checkpoint_to_pytorch.py
Normal file → Executable file
@@ -68,11 +68,17 @@ def convert():
|
|||||||
arrays.append(array)
|
arrays.append(array)
|
||||||
|
|
||||||
for name, array in zip(names, arrays):
|
for name, array in zip(names, arrays):
|
||||||
name = name[5:] # skip "bert/"
|
if not name.startswith("bert"):
|
||||||
|
print("Skipping {}".format(name))
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
name = name.replace("bert/", "") # skip "bert/"
|
||||||
print("Loading {}".format(name))
|
print("Loading {}".format(name))
|
||||||
name = name.split('/')
|
name = name.split('/')
|
||||||
if name[0] in ['redictions', 'eq_relationship']:
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
print("Skipping")
|
# which are not required for using pretrained model
|
||||||
|
if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m":
|
||||||
|
print("Skipping {}".format("/".join(name)))
|
||||||
continue
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
for m_name in name:
|
for m_name in name:
|
||||||
|
|||||||
Reference in New Issue
Block a user