updating code organization to fix imports
This commit is contained in:
@@ -32,7 +32,6 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .file_utils import cached_path
|
||||
from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,6 +39,67 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.h
|
||||
CONFIG_NAME = "openai_gpt_config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||
"""
|
||||
print("Loading weights...")
|
||||
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
|
||||
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
|
||||
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
||||
init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
|
||||
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||
del init_params[1]
|
||||
init_params = [arr.squeeze() for arr in init_params]
|
||||
|
||||
try:
|
||||
assert model.embed.weight.shape == init_params[0].shape
|
||||
except AssertionError as e:
|
||||
e.args += (model.embed.weight.shape, init_params[0].shape)
|
||||
raise
|
||||
|
||||
model.embed.weight.data = torch.from_numpy(init_params[0])
|
||||
names.pop(0)
|
||||
init_params.pop(0)
|
||||
|
||||
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
|
||||
name = name[6:] # skip "model/"
|
||||
assert name[-2:] == ":0"
|
||||
name = name[:-2]
|
||||
name = name.split('/')
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
|
||||
l = re.split(r'(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'g':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'b':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'w':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
else:
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
Reference in New Issue
Block a user