splitting position and tokens embeddings in OpenAI GPT - updating tf imports - tests

This commit is contained in:
thomwolf
2019-01-29 10:31:42 +01:00
parent 5456d82311
commit 98c96fb1a7
7 changed files with 66 additions and 44 deletions

View File

@@ -52,6 +52,14 @@ TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
import re
import numpy as np
import tensorflow as tf
except ModuleNotFoundError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model