updating examples
This commit is contained in:
@@ -49,17 +49,17 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array.squeeze())
|
||||
@@ -90,7 +90,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
logger.info("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user