Merge branch 'xlnet' into doc-sphinx
This commit is contained in:
@@ -73,17 +73,17 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
@@ -93,7 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
|
||||
print("Skipping {}".format("/".join(name)))
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
@@ -113,7 +113,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
try:
|
||||
pointer = getattr(pointer, l[0])
|
||||
except AttributeError:
|
||||
print("Skipping {}".format("/".join(name)))
|
||||
logger.info("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
@@ -127,7 +127,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
logger.info("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user