Fixes to the TensorFlow conversion tool
This commit is contained in:
@@ -57,7 +57,7 @@ class InputFeatures(object):
|
|||||||
|
|
||||||
|
|
||||||
def convert_examples_to_features(examples, seq_length, tokenizer):
|
def convert_examples_to_features(examples, seq_length, tokenizer):
|
||||||
"""Loads a data file into a list of `InputBatch`s."""
|
"""Loads a data file into a list of `InputFeature`s."""
|
||||||
|
|
||||||
features = []
|
features = []
|
||||||
for (ex_index, example) in enumerate(examples):
|
for (ex_index, example) in enumerate(examples):
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
|||||||
name = name.split('/')
|
name = name.split('/')
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
# which are not required for using pretrained model
|
# which are not required for using pretrained model
|
||||||
if any(n in ["adam_v", "adam_m"] for n in name):
|
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||||
print("Skipping {}".format("/".join(name)))
|
print("Skipping {}".format("/".join(name)))
|
||||||
continue
|
continue
|
||||||
pointer = model
|
pointer = model
|
||||||
@@ -92,7 +92,11 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
|||||||
elif l[0] == 'output_weights':
|
elif l[0] == 'output_weights':
|
||||||
pointer = getattr(pointer, 'weight')
|
pointer = getattr(pointer, 'weight')
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
pointer = getattr(pointer, l[0])
|
pointer = getattr(pointer, l[0])
|
||||||
|
except AttributeError:
|
||||||
|
print("Skipping {}".format("/".join(name)))
|
||||||
|
continue
|
||||||
if len(l) >= 2:
|
if len(l) >= 2:
|
||||||
num = int(l[1])
|
num = int(l[1])
|
||||||
pointer = pointer[num]
|
pointer = pointer[num]
|
||||||
|
|||||||
Reference in New Issue
Block a user