updating examples

This commit is contained in:
thomwolf
2019-07-11 12:03:08 +02:00
parent 50b7e52a7f
commit 4fef5919a5
10 changed files with 116 additions and 150 deletions

View File

@@ -73,17 +73,17 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
logger.info("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
@@ -93,7 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
print("Skipping {}".format("/".join(name)))
logger.info("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
@@ -113,7 +113,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
try:
pointer = getattr(pointer, l[0])
except AttributeError:
print("Skipping {}".format("/".join(name)))
logger.info("Skipping {}".format("/".join(name)))
continue
if len(l) >= 2:
num = int(l[1])
@@ -127,7 +127,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model