updating examples

This commit is contained in:
thomwolf
2019-07-11 12:03:08 +02:00
parent 50b7e52a7f
commit 4fef5919a5
10 changed files with 116 additions and 150 deletions

View File

@@ -49,17 +49,17 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger.error("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(gpt2_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
logger.info("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array.squeeze())
@@ -90,7 +90,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model