updating code organization to fix imports
This commit is contained in:
@@ -24,7 +24,7 @@ import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
||||
from .modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||
# Construct model
|
||||
@@ -46,66 +46,6 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||
"""
|
||||
print("Loading weights...")
|
||||
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
|
||||
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
|
||||
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
||||
init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
|
||||
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||
del init_params[1]
|
||||
init_params = [arr.squeeze() for arr in init_params]
|
||||
|
||||
try:
|
||||
assert model.embed.weight.shape == init_params[0].shape
|
||||
except AssertionError as e:
|
||||
e.args += (model.embed.weight.shape, init_params[0].shape)
|
||||
raise
|
||||
|
||||
model.embed.weight.data = torch.from_numpy(init_params[0])
|
||||
names.pop(0)
|
||||
init_params.pop(0)
|
||||
|
||||
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
|
||||
name = name[6:] # skip "model/"
|
||||
assert name[-2:] == ":0"
|
||||
name = name[:-2]
|
||||
name = name.split('/')
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
|
||||
l = re.split(r'(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'g':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'b':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'w':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
else:
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -25,7 +25,7 @@ import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .modeling import BertConfig, BertForPreTraining
|
||||
from .modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
@@ -40,57 +40,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m"] for n in name):
|
||||
print("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
else:
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -27,7 +27,7 @@ import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME
|
||||
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME, load_tf_weights_in_transfo_xl
|
||||
from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
|
||||
|
||||
# We do this to be able to load the python 2 datasets pickles
|
||||
@@ -38,74 +38,6 @@ data_utils.Corpus = data_utils.TransfoXLCorpus
|
||||
sys.modules['data_utils'] = data_utils
|
||||
sys.modules['vocabulary'] = data_utils
|
||||
|
||||
def build_tf_to_pytorch_map(model, config):
|
||||
""" A map of modules from TF to PyTorch.
|
||||
This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
|
||||
"""
|
||||
tf_to_pt_map = {}
|
||||
# Embeddings cutoffs
|
||||
for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
|
||||
layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
|
||||
tf_to_pt_map.update({
|
||||
layer_str + 'lookup_table': embed_l.weight,
|
||||
layer_str + 'proj_W': proj_l
|
||||
})
|
||||
|
||||
# Transformer blocks
|
||||
for i, b in enumerate(model.layers):
|
||||
layer_str = "transformer/layer_%d/" % i
|
||||
tf_to_pt_map.update({
|
||||
layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
|
||||
layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
|
||||
layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
|
||||
layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
|
||||
layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
|
||||
layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
|
||||
layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
|
||||
layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
|
||||
layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
|
||||
layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
|
||||
layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
|
||||
})
|
||||
|
||||
# Adaptive Softmax
|
||||
tf_to_pt_map.update({
|
||||
"transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
|
||||
"transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
|
||||
for i, (out_l, proj_l, tie_proj) in enumerate(zip(
|
||||
model.crit.out_layers,
|
||||
model.crit.out_projs,
|
||||
config.tie_projs)):
|
||||
layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
|
||||
if config.tie_weight:
|
||||
tf_to_pt_map.update({
|
||||
layer_str + 'b': out_l.bias})
|
||||
else:
|
||||
raise NotImplementedError
|
||||
# I don't think this is implemented in the TF code
|
||||
tf_to_pt_map.update({
|
||||
layer_str + 'lookup_table': out_l.weight,
|
||||
layer_str + 'b': out_l.bias})
|
||||
if not tie_proj:
|
||||
tf_to_pt_map.update({
|
||||
layer_str + 'proj': proj_l
|
||||
})
|
||||
|
||||
# Relative positioning biases
|
||||
if config.untie_r:
|
||||
r_r_list = []
|
||||
r_w_list = []
|
||||
for b in model.layers:
|
||||
r_r_list.append(b.dec_attn.r_r_bias)
|
||||
r_w_list.append(b.dec_attn.r_w_bias)
|
||||
else:
|
||||
r_r_list = [model.r_r_bias]
|
||||
r_w_list = [model.r_w_bias]
|
||||
tf_to_pt_map.update({
|
||||
'transformer/r_r_bias': r_r_list,
|
||||
'transformer/r_w_bias': r_w_list})
|
||||
return tf_to_pt_map
|
||||
|
||||
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
transfo_xl_config_file,
|
||||
pytorch_dump_folder_path,
|
||||
@@ -150,54 +82,6 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
f.write(config.to_json_string())
|
||||
|
||||
|
||||
def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
# Build TF to PyTorch weights loading map
|
||||
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
|
||||
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
tf_weights = {}
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
tf_weights[name] = array
|
||||
|
||||
for name, pointer in tf_to_pt_map.items():
|
||||
assert name in tf_weights
|
||||
array = tf_weights[name]
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if 'kernel' in name or 'proj' in name:
|
||||
array = np.transpose(array)
|
||||
if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
|
||||
# Here we will split the TF weigths
|
||||
assert len(pointer) == array.shape[0]
|
||||
for i, p_i in enumerate(pointer):
|
||||
arr_i = array[i, ...]
|
||||
try:
|
||||
assert p_i.shape == arr_i.shape
|
||||
except AssertionError as e:
|
||||
e.args += (p_i.shape, arr_i.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {} for layer {}".format(name, i))
|
||||
p_i.data = torch.from_numpy(arr_i)
|
||||
else:
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
tf_weights.pop(name, None)
|
||||
tf_weights.pop(name + '/Adam', None)
|
||||
tf_weights.pop(name + '/Adam_1', None)
|
||||
|
||||
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||
return model
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
## Required parameters
|
||||
|
||||
@@ -33,7 +33,6 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,6 +49,59 @@ CONFIG_NAME = 'bert_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in ["adam_v", "adam_m"] for n in name):
|
||||
print("Skipping {}".format("/".join(name)))
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
l = re.split(r'_(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'kernel' or l[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'output_bias' or l[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
else:
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
|
||||
@@ -32,7 +32,6 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .file_utils import cached_path
|
||||
from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,6 +39,67 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.h
|
||||
CONFIG_NAME = "openai_gpt_config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||
"""
|
||||
print("Loading weights...")
|
||||
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
|
||||
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
|
||||
offsets = np.cumsum([np.prod(shape) for shape in shapes])
|
||||
init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
|
||||
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||
del init_params[1]
|
||||
init_params = [arr.squeeze() for arr in init_params]
|
||||
|
||||
try:
|
||||
assert model.embed.weight.shape == init_params[0].shape
|
||||
except AssertionError as e:
|
||||
e.args += (model.embed.weight.shape, init_params[0].shape)
|
||||
raise
|
||||
|
||||
model.embed.weight.data = torch.from_numpy(init_params[0])
|
||||
names.pop(0)
|
||||
init_params.pop(0)
|
||||
|
||||
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
|
||||
name = name[6:] # skip "model/"
|
||||
assert name[-2:] == ":0"
|
||||
name = name[:-2]
|
||||
name = name.split('/')
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+\d+', m_name):
|
||||
l = re.split(r'(\d+)', m_name)
|
||||
else:
|
||||
l = [m_name]
|
||||
if l[0] == 'g':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif l[0] == 'b':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif l[0] == 'w':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
else:
|
||||
pointer = getattr(pointer, l[0])
|
||||
if len(l) >= 2:
|
||||
num = int(l[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
@@ -37,7 +37,6 @@ from torch.nn.parameter import Parameter
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
||||
from .file_utils import cached_path
|
||||
from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +50,123 @@ CONFIG_NAME = 'transfo_xl_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def build_tf_to_pytorch_map(model, config):
|
||||
""" A map of modules from TF to PyTorch.
|
||||
This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
|
||||
"""
|
||||
tf_to_pt_map = {}
|
||||
# Embeddings cutoffs
|
||||
for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
|
||||
layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
|
||||
tf_to_pt_map.update({
|
||||
layer_str + 'lookup_table': embed_l.weight,
|
||||
layer_str + 'proj_W': proj_l
|
||||
})
|
||||
|
||||
# Transformer blocks
|
||||
for i, b in enumerate(model.layers):
|
||||
layer_str = "transformer/layer_%d/" % i
|
||||
tf_to_pt_map.update({
|
||||
layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
|
||||
layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
|
||||
layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
|
||||
layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
|
||||
layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
|
||||
layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
|
||||
layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
|
||||
layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
|
||||
layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
|
||||
layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
|
||||
layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
|
||||
})
|
||||
|
||||
# Adaptive Softmax
|
||||
tf_to_pt_map.update({
|
||||
"transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
|
||||
"transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias})
|
||||
for i, (out_l, proj_l, tie_proj) in enumerate(zip(
|
||||
model.crit.out_layers,
|
||||
model.crit.out_projs,
|
||||
config.tie_projs)):
|
||||
layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
|
||||
if config.tie_weight:
|
||||
tf_to_pt_map.update({
|
||||
layer_str + 'b': out_l.bias})
|
||||
else:
|
||||
raise NotImplementedError
|
||||
# I don't think this is implemented in the TF code
|
||||
tf_to_pt_map.update({
|
||||
layer_str + 'lookup_table': out_l.weight,
|
||||
layer_str + 'b': out_l.bias})
|
||||
if not tie_proj:
|
||||
tf_to_pt_map.update({
|
||||
layer_str + 'proj': proj_l
|
||||
})
|
||||
|
||||
# Relative positioning biases
|
||||
if config.untie_r:
|
||||
r_r_list = []
|
||||
r_w_list = []
|
||||
for b in model.layers:
|
||||
r_r_list.append(b.dec_attn.r_r_bias)
|
||||
r_w_list.append(b.dec_attn.r_w_bias)
|
||||
else:
|
||||
r_r_list = [model.r_r_bias]
|
||||
r_w_list = [model.r_w_bias]
|
||||
tf_to_pt_map.update({
|
||||
'transformer/r_r_bias': r_r_list,
|
||||
'transformer/r_w_bias': r_w_list})
|
||||
return tf_to_pt_map
|
||||
|
||||
def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
# Build TF to PyTorch weights loading map
|
||||
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
|
||||
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
tf_weights = {}
|
||||
for name, shape in init_vars:
|
||||
print("Loading TF weight {} with shape {}".format(name, shape))
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
tf_weights[name] = array
|
||||
|
||||
for name, pointer in tf_to_pt_map.items():
|
||||
assert name in tf_weights
|
||||
array = tf_weights[name]
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if 'kernel' in name or 'proj' in name:
|
||||
array = np.transpose(array)
|
||||
if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1:
|
||||
# Here we will split the TF weigths
|
||||
assert len(pointer) == array.shape[0]
|
||||
for i, p_i in enumerate(pointer):
|
||||
arr_i = array[i, ...]
|
||||
try:
|
||||
assert p_i.shape == arr_i.shape
|
||||
except AssertionError as e:
|
||||
e.args += (p_i.shape, arr_i.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {} for layer {}".format(name, i))
|
||||
p_i.data = torch.from_numpy(arr_i)
|
||||
else:
|
||||
try:
|
||||
assert pointer.shape == array.shape
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
tf_weights.pop(name, None)
|
||||
tf_weights.pop(name + '/Adam', None)
|
||||
tf_weights.pop(name + '/Adam_1', None)
|
||||
|
||||
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||
return model
|
||||
|
||||
|
||||
class TransfoXLConfig(object):
|
||||
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
||||
"""
|
||||
|
||||
@@ -291,7 +291,7 @@ if __name__ == '__main__':
|
||||
# sampler = LogUniformSampler(n_vocab, unique=False)
|
||||
# new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
|
||||
|
||||
sampler = LogUniformSampler(n_vocab, unique=True)
|
||||
sampler = LogUniformSampler(n_vocab, n_sample)#, unique=True)
|
||||
# true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
|
||||
|
||||
# print('true_probs', true_probs.numpy().tolist())
|
||||
|
||||
Reference in New Issue
Block a user