updating code organization to fix imports
This commit is contained in:
@@ -24,7 +24,7 @@ import argparse
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
from .modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
||||||
|
|
||||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||||
# Construct model
|
# Construct model
|
||||||
@@ -46,66 +46,6 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
|||||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||||
f.write(config.to_json_string())
|
f.write(config.to_json_string())
|
||||||
|
|
||||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
    """Load pre-trained OpenAI GPT weights (stored as NumPy arrays) into a PyTorch model.

    The checkpoint folder must contain ``parameters_names.json``,
    ``params_shapes.json`` and the ten shards ``params_0.npy`` .. ``params_9.npy``.
    The shards are concatenated, cut back into individual parameters using the
    stored shapes, and each array is copied into the attribute of ``model``
    addressed by its parameter name.

    Args:
        model: object exposing ``embed.weight`` plus every attribute path named
            in ``parameters_names.json``.
        openai_checkpoint_folder_path: path of the checkpoint folder (file
            names are appended with ``'/'``).

    Returns:
        The same ``model`` object, with the ``.data`` of its parameters replaced.

    Raises:
        AssertionError: if a parameter name does not end with ``":0"``, or if a
            loaded array and its destination have different shapes (both shapes
            are appended to the exception ``args``).
    """
    print("Loading weights...")
    # Use context managers so the JSON file handles are closed deterministically.
    with open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8') as names_file:
        names = json.load(names_file)
    with open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8') as shapes_file:
        shapes = json.load(shapes_file)
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
    # Splitting at the cumulative offsets yields one extra trailing piece, which is dropped.
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]

    # The first two arrays are merged (second one first) into the single
    # matrix that is stored in ``model.embed.weight``.
    init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
    del init_params[1]
    init_params = [arr.squeeze() for arr in init_params]

    try:
        assert model.embed.weight.shape == init_params[0].shape
    except AssertionError as e:
        e.args += (model.embed.weight.shape, init_params[0].shape)
        raise

    model.embed.weight.data = torch.from_numpy(init_params[0])
    names.pop(0)
    init_params.pop(0)

    for name, array in zip(names, init_params):
        name = name[6:]  # skip "model/"
        assert name[-2:] == ":0"
        name = name[:-2]
        name = name.split('/')
        pointer = model
        for m_name in name:
            # "h3" -> ['h', '3', '']: attribute name followed by an index.
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                parts = re.split(r'(\d+)', m_name)
            else:
                parts = [m_name]
            if parts[0] in ('g', 'w'):
                pointer = getattr(pointer, 'weight')
            elif parts[0] == 'b':
                pointer = getattr(pointer, 'bias')
            else:
                pointer = getattr(pointer, parts[0])
            if len(parts) >= 2:
                pointer = pointer[int(parts[1])]
        # Single shape check (the original repeated this block twice).
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import tensorflow as tf
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .modeling import BertConfig, BertForPreTraining
|
from .modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||||
# Initialise PyTorch model
|
# Initialise PyTorch model
|
||||||
@@ -40,57 +40,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
|||||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||||
torch.save(model.state_dict(), pytorch_dump_path)
|
torch.save(model.state_dict(), pytorch_dump_path)
|
||||||
|
|
||||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """Copy the variables of a TensorFlow checkpoint into a PyTorch model.

    Every variable listed in the checkpoint is loaded, its '/'-separated name
    is followed attribute by attribute starting from ``model``, and the array
    is assigned to the ``.data`` of the tensor found at the end of that path.
    Variables whose path contains ``adam_v`` or ``adam_m`` are skipped.

    Args:
        model: PyTorch model whose attribute layout mirrors the TF variable names.
        tf_checkpoint_path: path of the TensorFlow checkpoint.

    Returns:
        The same ``model`` object, with its parameter data replaced.

    Raises:
        AssertionError: if a destination tensor and the loaded array differ in
            shape (both shapes are appended to the exception ``args``).
    """
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(tf_path))

    # Read every variable of the checkpoint up front.
    loaded = []
    for var_name, var_shape in tf.train.list_variables(tf_path):
        print("Loading TF weight {} with shape {}".format(var_name, var_shape))
        loaded.append((var_name, tf.train.load_variable(tf_path, var_name)))

    for full_name, array in loaded:
        name = full_name.split('/')
        # Optimizer slots (adam_v / adam_m) are not needed to use the pretrained model.
        if "adam_v" in name or "adam_m" in name:
            print("Skipping {}".format("/".join(name)))
            continue

        pointer = model
        for m_name in name:
            # "layer_3" -> ['layer', '3', '']: attribute name followed by an index.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                pieces = re.split(r'_(\d+)', m_name)
            else:
                pieces = [m_name]
            head = pieces[0]
            if head in ('kernel', 'gamma', 'output_weights'):
                pointer = getattr(pointer, 'weight')
            elif head in ('output_bias', 'beta'):
                pointer = getattr(pointer, 'bias')
            else:
                pointer = getattr(pointer, head)
            if len(pieces) >= 2:
                pointer = pointer[int(pieces[1])]

        # ``m_name`` is the last path component here.
        if m_name.endswith('_embeddings'):
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)

        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import tensorflow as tf
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
|
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME, load_tf_weights_in_transfo_xl
|
||||||
from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
|
from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
|
||||||
|
|
||||||
# We do this to be able to load the python 2 datasets pickles
|
# We do this to be able to load the python 2 datasets pickles
|
||||||
@@ -38,74 +38,6 @@ data_utils.Corpus = data_utils.TransfoXLCorpus
|
|||||||
sys.modules['data_utils'] = data_utils
|
sys.modules['data_utils'] = data_utils
|
||||||
sys.modules['vocabulary'] = data_utils
|
sys.modules['vocabulary'] = data_utils
|
||||||
|
|
||||||
def build_tf_to_pytorch_map(model, config):
    """Build a map from TF variable names to the matching PyTorch tensors.

    A map is used (rather than walking attribute names) so the PyTorch model
    can stay as close as possible to the original PyTorch implementation.

    Args:
        model: object exposing ``word_emb.emb_layers``, ``word_emb.emb_projs``,
            ``layers``, ``crit`` and, when ``config.untie_r`` is false,
            ``r_r_bias`` / ``r_w_bias``.
        config: object exposing ``tie_projs``, ``tie_weight`` and ``untie_r``.

    Returns:
        dict mapping each TF variable name to a PyTorch tensor, except for
        ``'transformer/r_r_bias'`` and ``'transformer/r_w_bias'`` which map to
        a *list* of tensors (one per layer when ``config.untie_r`` is true,
        otherwise a single-element list).

    Raises:
        NotImplementedError: if ``config.tie_weight`` is false and the adaptive
            softmax has at least one output layer.
    """
    tf_to_pt_map = {}

    # Embeddings cutoffs
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
        tf_to_pt_map.update({
            layer_str + 'lookup_table': embed_l.weight,
            layer_str + 'proj_W': proj_l,
        })

    # Transformer blocks
    for i, b in enumerate(model.layers):
        layer_str = "transformer/layer_%d/" % i
        tf_to_pt_map.update({
            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
        })

    # Adaptive Softmax
    tf_to_pt_map.update({
        "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
        "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias,
    })
    for i, (out_l, proj_l, tie_proj) in enumerate(zip(
            model.crit.out_layers,
            model.crit.out_projs,
            config.tie_projs)):
        layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
        if not config.tie_weight:
            # The original author did not think the untied case exists in the
            # TF code; the mapping that used to follow this raise was
            # unreachable and has been removed.
            raise NotImplementedError(
                "Loading TF weights with config.tie_weight=False is not supported")
        tf_to_pt_map[layer_str + 'b'] = out_l.bias
        if not tie_proj:
            tf_to_pt_map[layer_str + 'proj'] = proj_l

    # Relative positioning biases
    if config.untie_r:
        r_r_list = [b.dec_attn.r_r_bias for b in model.layers]
        r_w_list = [b.dec_attn.r_w_bias for b in model.layers]
    else:
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map.update({
        'transformer/r_r_bias': r_r_list,
        'transformer/r_w_bias': r_w_list,
    })
    return tf_to_pt_map
|
|
||||||
|
|
||||||
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||||
transfo_xl_config_file,
|
transfo_xl_config_file,
|
||||||
pytorch_dump_folder_path,
|
pytorch_dump_folder_path,
|
||||||
@@ -150,54 +82,6 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
|||||||
f.write(config.to_json_string())
|
f.write(config.to_json_string())
|
||||||
|
|
||||||
|
|
||||||
def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """Copy the variables of a TensorFlow Transformer-XL checkpoint into a PyTorch model.

    The destination of every TF variable comes from ``build_tf_to_pytorch_map``.
    Arrays whose name contains ``kernel`` or ``proj`` are transposed before
    being copied.

    Args:
        model: PyTorch model to fill.
        config: configuration object forwarded to ``build_tf_to_pytorch_map``.
        tf_path: path of the TensorFlow checkpoint.

    Returns:
        The same ``model`` object, with its parameter data replaced.

    Raises:
        AssertionError: if a mapped variable is missing from the checkpoint, or
            if a destination tensor and the loaded array differ in shape (both
            shapes are appended to the exception ``args``).
    """
    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    tf_weights = {}
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        tf_weights[name] = tf.train.load_variable(tf_path, name)

    for name, pointer in tf_to_pt_map.items():
        assert name in tf_weights, "TF variable {} not found in checkpoint".format(name)
        array = tf_weights[name]
        # The TF array is transposed to match the PyTorch destination shape.
        if 'kernel' in name or 'proj' in name:
            array = np.transpose(array)
        is_rel_bias = 'r_r_bias' in name or 'r_w_bias' in name
        if is_rel_bias and len(pointer) > 1:
            # Here we will split the TF weights: one slice per layer
            assert len(pointer) == array.shape[0]
            for i, p_i in enumerate(pointer):
                arr_i = array[i, ...]
                try:
                    assert p_i.shape == arr_i.shape
                except AssertionError as e:
                    e.args += (p_i.shape, arr_i.shape)
                    raise
                print("Initialize PyTorch weight {} for layer {}".format(name, i))
                p_i.data = torch.from_numpy(arr_i)
        else:
            # The relative-position biases are always mapped to a list; with a
            # single shared bias that list has one element and must be
            # unwrapped, otherwise ``pointer.shape`` fails on the list itself.
            # NOTE(review): assumes the TF array then has the shape of that
            # single bias -- confirm against a tied-bias checkpoint.
            if is_rel_bias and isinstance(pointer, (list, tuple)):
                pointer = pointer[0]
            try:
                assert pointer.shape == array.shape
            except AssertionError as e:
                e.args += (pointer.shape, array.shape)
                raise
            print("Initialize PyTorch weight {}".format(name))
            pointer.data = torch.from_numpy(array)
        # Drop the copied variable and its optimizer slots so that only
        # weights that were NOT copied remain for the final report.
        tf_weights.pop(name, None)
        tf_weights.pop(name + '/Adam', None)
        tf_weights.pop(name + '/Adam_1', None)

    print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
## Required parameters
|
## Required parameters
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from torch import nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -50,6 +49,59 @@ CONFIG_NAME = 'bert_config.json'
|
|||||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
|
|
||||||
|
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """Copy the variables of a TensorFlow checkpoint into a PyTorch model.

    Every variable listed in the checkpoint is loaded, its '/'-separated name
    is followed attribute by attribute starting from ``model``, and the array
    is assigned to the ``.data`` of the tensor found at the end of that path.
    Variables whose path contains ``adam_v`` or ``adam_m`` are skipped.

    Args:
        model: PyTorch model whose attribute layout mirrors the TF variable names.
        tf_checkpoint_path: path of the TensorFlow checkpoint.

    Returns:
        The same ``model`` object, with its parameter data replaced.

    Raises:
        AssertionError: if a destination tensor and the loaded array differ in
            shape (both shapes are appended to the exception ``args``).
    """
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(tf_path))

    # Read every variable of the checkpoint up front.
    loaded = []
    for var_name, var_shape in tf.train.list_variables(tf_path):
        print("Loading TF weight {} with shape {}".format(var_name, var_shape))
        loaded.append((var_name, tf.train.load_variable(tf_path, var_name)))

    for full_name, array in loaded:
        name = full_name.split('/')
        # Optimizer slots (adam_v / adam_m) are not needed to use the pretrained model.
        if "adam_v" in name or "adam_m" in name:
            print("Skipping {}".format("/".join(name)))
            continue

        pointer = model
        for m_name in name:
            # "layer_3" -> ['layer', '3', '']: attribute name followed by an index.
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                pieces = re.split(r'_(\d+)', m_name)
            else:
                pieces = [m_name]
            head = pieces[0]
            if head in ('kernel', 'gamma', 'output_weights'):
                pointer = getattr(pointer, 'weight')
            elif head in ('output_bias', 'beta'):
                pointer = getattr(pointer, 'bias')
            else:
                pointer = getattr(pointer, head)
            if len(pieces) >= 2:
                pointer = pointer[int(pieces[1])]

        # ``m_name`` is the last path component here.
        if m_name.endswith('_embeddings'):
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)

        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
|
||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
"""Implementation of the gelu activation function.
|
"""Implementation of the gelu activation function.
|
||||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from torch.nn.parameter import Parameter
|
|||||||
|
|
||||||
from .modeling import BertLayerNorm as LayerNorm
|
from .modeling import BertLayerNorm as LayerNorm
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -40,6 +39,67 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.h
|
|||||||
CONFIG_NAME = "openai_gpt_config.json"
|
CONFIG_NAME = "openai_gpt_config.json"
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
|
|
||||||
|
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
    """Load pre-trained OpenAI GPT weights (stored as NumPy arrays) into a PyTorch model.

    The checkpoint folder must contain ``parameters_names.json``,
    ``params_shapes.json`` and the ten shards ``params_0.npy`` .. ``params_9.npy``.
    The shards are concatenated, cut back into individual parameters using the
    stored shapes, and each array is copied into the attribute of ``model``
    addressed by its parameter name.

    Args:
        model: object exposing ``embed.weight`` plus every attribute path named
            in ``parameters_names.json``.
        openai_checkpoint_folder_path: path of the checkpoint folder (file
            names are appended with ``'/'``).

    Returns:
        The same ``model`` object, with the ``.data`` of its parameters replaced.

    Raises:
        AssertionError: if a parameter name does not end with ``":0"``, or if a
            loaded array and its destination have different shapes (both shapes
            are appended to the exception ``args``).
    """
    print("Loading weights...")
    # Use context managers so the JSON file handles are closed deterministically.
    with open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8') as names_file:
        names = json.load(names_file)
    with open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8') as shapes_file:
        shapes = json.load(shapes_file)
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
    # Splitting at the cumulative offsets yields one extra trailing piece, which is dropped.
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]

    # The first two arrays are merged (second one first) into the single
    # matrix that is stored in ``model.embed.weight``.
    init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
    del init_params[1]
    init_params = [arr.squeeze() for arr in init_params]

    try:
        assert model.embed.weight.shape == init_params[0].shape
    except AssertionError as e:
        e.args += (model.embed.weight.shape, init_params[0].shape)
        raise

    model.embed.weight.data = torch.from_numpy(init_params[0])
    names.pop(0)
    init_params.pop(0)

    for name, array in zip(names, init_params):
        name = name[6:]  # skip "model/"
        assert name[-2:] == ":0"
        name = name[:-2]
        name = name.split('/')
        pointer = model
        for m_name in name:
            # "h3" -> ['h', '3', '']: attribute name followed by an index.
            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                parts = re.split(r'(\d+)', m_name)
            else:
                parts = [m_name]
            if parts[0] in ('g', 'w'):
                pointer = getattr(pointer, 'weight')
            elif parts[0] == 'b':
                pointer = getattr(pointer, 'bias')
            else:
                pointer = getattr(pointer, parts[0])
            if len(parts) >= 2:
                pointer = pointer[int(parts[1])]
        # Single shape check (the original repeated this block twice).
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
|
||||||
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from torch.nn.parameter import Parameter
|
|||||||
from .modeling import BertLayerNorm as LayerNorm
|
from .modeling import BertLayerNorm as LayerNorm
|
||||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -51,6 +50,123 @@ CONFIG_NAME = 'transfo_xl_config.json'
|
|||||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
|
|
||||||
|
def build_tf_to_pytorch_map(model, config):
    """Build a map from TF variable names to the matching PyTorch tensors.

    A map is used (rather than walking attribute names) so the PyTorch model
    can stay as close as possible to the original PyTorch implementation.

    Args:
        model: object exposing ``word_emb.emb_layers``, ``word_emb.emb_projs``,
            ``layers``, ``crit`` and, when ``config.untie_r`` is false,
            ``r_r_bias`` / ``r_w_bias``.
        config: object exposing ``tie_projs``, ``tie_weight`` and ``untie_r``.

    Returns:
        dict mapping each TF variable name to a PyTorch tensor, except for
        ``'transformer/r_r_bias'`` and ``'transformer/r_w_bias'`` which map to
        a *list* of tensors (one per layer when ``config.untie_r`` is true,
        otherwise a single-element list).

    Raises:
        NotImplementedError: if ``config.tie_weight`` is false and the adaptive
            softmax has at least one output layer.
    """
    tf_to_pt_map = {}

    # Embeddings cutoffs
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
        tf_to_pt_map.update({
            layer_str + 'lookup_table': embed_l.weight,
            layer_str + 'proj_W': proj_l,
        })

    # Transformer blocks
    for i, b in enumerate(model.layers):
        layer_str = "transformer/layer_%d/" % i
        tf_to_pt_map.update({
            layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
            layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
            layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
            layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
            layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
            layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
            layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
            layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
            layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
            layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
            layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
        })

    # Adaptive Softmax
    tf_to_pt_map.update({
        "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
        "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias,
    })
    for i, (out_l, proj_l, tie_proj) in enumerate(zip(
            model.crit.out_layers,
            model.crit.out_projs,
            config.tie_projs)):
        layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
        if not config.tie_weight:
            # The original author did not think the untied case exists in the
            # TF code; the mapping that used to follow this raise was
            # unreachable and has been removed.
            raise NotImplementedError(
                "Loading TF weights with config.tie_weight=False is not supported")
        tf_to_pt_map[layer_str + 'b'] = out_l.bias
        if not tie_proj:
            tf_to_pt_map[layer_str + 'proj'] = proj_l

    # Relative positioning biases
    if config.untie_r:
        r_r_list = [b.dec_attn.r_r_bias for b in model.layers]
        r_w_list = [b.dec_attn.r_w_bias for b in model.layers]
    else:
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map.update({
        'transformer/r_r_bias': r_r_list,
        'transformer/r_w_bias': r_w_list,
    })
    return tf_to_pt_map
|
||||||
|
|
||||||
|
def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """Copy the variables of a TensorFlow Transformer-XL checkpoint into a PyTorch model.

    The destination of every TF variable comes from ``build_tf_to_pytorch_map``.
    Arrays whose name contains ``kernel`` or ``proj`` are transposed before
    being copied.

    Args:
        model: PyTorch model to fill.
        config: configuration object forwarded to ``build_tf_to_pytorch_map``.
        tf_path: path of the TensorFlow checkpoint.

    Returns:
        The same ``model`` object, with its parameter data replaced.

    Raises:
        AssertionError: if a mapped variable is missing from the checkpoint, or
            if a destination tensor and the loaded array differ in shape (both
            shapes are appended to the exception ``args``).
    """
    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    tf_weights = {}
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        tf_weights[name] = tf.train.load_variable(tf_path, name)

    for name, pointer in tf_to_pt_map.items():
        assert name in tf_weights, "TF variable {} not found in checkpoint".format(name)
        array = tf_weights[name]
        # The TF array is transposed to match the PyTorch destination shape.
        if 'kernel' in name or 'proj' in name:
            array = np.transpose(array)
        is_rel_bias = 'r_r_bias' in name or 'r_w_bias' in name
        if is_rel_bias and len(pointer) > 1:
            # Here we will split the TF weights: one slice per layer
            assert len(pointer) == array.shape[0]
            for i, p_i in enumerate(pointer):
                arr_i = array[i, ...]
                try:
                    assert p_i.shape == arr_i.shape
                except AssertionError as e:
                    e.args += (p_i.shape, arr_i.shape)
                    raise
                print("Initialize PyTorch weight {} for layer {}".format(name, i))
                p_i.data = torch.from_numpy(arr_i)
        else:
            # The relative-position biases are always mapped to a list; with a
            # single shared bias that list has one element and must be
            # unwrapped, otherwise ``pointer.shape`` fails on the list itself.
            # NOTE(review): assumes the TF array then has the shape of that
            # single bias -- confirm against a tied-bias checkpoint.
            if is_rel_bias and isinstance(pointer, (list, tuple)):
                pointer = pointer[0]
            try:
                assert pointer.shape == array.shape
            except AssertionError as e:
                e.args += (pointer.shape, array.shape)
                raise
            print("Initialize PyTorch weight {}".format(name))
            pointer.data = torch.from_numpy(array)
        # Drop the copied variable and its optimizer slots so that only
        # weights that were NOT copied remain for the final report.
        tf_weights.pop(name, None)
        tf_weights.pop(name + '/Adam', None)
        tf_weights.pop(name + '/Adam_1', None)

    print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model
|
||||||
|
|
||||||
|
|
||||||
class TransfoXLConfig(object):
|
class TransfoXLConfig(object):
|
||||||
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -291,7 +291,7 @@ if __name__ == '__main__':
|
|||||||
# sampler = LogUniformSampler(n_vocab, unique=False)
|
# sampler = LogUniformSampler(n_vocab, unique=False)
|
||||||
# new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
|
# new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
|
||||||
|
|
||||||
sampler = LogUniformSampler(n_vocab, unique=True)
|
sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True)
|
||||||
# true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
|
# true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
|
||||||
|
|
||||||
# print('true_probs', true_probs.numpy().tolist())
|
# print('true_probs', true_probs.numpy().tolist())
|
||||||
|
|||||||
Reference in New Issue
Block a user