From 8b5c63e4deffad8c1c421caee8fef4bb97881f70 Mon Sep 17 00:00:00 2001 From: Mike Arpaia Date: Mon, 1 Apr 2019 12:53:51 -0600 Subject: [PATCH] Fixes to the TensorFlow conversion tool --- examples/extract_features.py | 2 +- pytorch_pretrained_bert/modeling.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/extract_features.py b/examples/extract_features.py index 0d59aa7e81..13384a9d69 100644 --- a/examples/extract_features.py +++ b/examples/extract_features.py @@ -57,7 +57,7 @@ class InputFeatures(object): def convert_examples_to_features(examples, seq_length, tokenizer): - """Loads a data file into a list of `InputBatch`s.""" + """Loads a data file into a list of `InputFeature`s.""" features = [] for (ex_index, example) in enumerate(examples): diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index b92f3a87f1..938636142f 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -76,7 +76,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path): name = name.split('/') # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model - if any(n in ["adam_v", "adam_m"] for n in name): + if any(n in ["adam_v", "adam_m", "global_step"] for n in name): print("Skipping {}".format("/".join(name))) continue pointer = model @@ -92,7 +92,11 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path): elif l[0] == 'output_weights': pointer = getattr(pointer, 'weight') else: - pointer = getattr(pointer, l[0]) + try: + pointer = getattr(pointer, l[0]) + except AttributeError: + print("Skipping {}".format("/".join(name))) + continue if len(l) >= 2: num = int(l[1]) pointer = pointer[num]