splitting position and tokens embeddings in OpenAI GPT - updating tf imports - tests
This commit is contained in:
@@ -14,7 +14,7 @@ def main():
|
||||
else:
|
||||
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
|
||||
try:
|
||||
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
|
||||
import tensorflow as tf
|
||||
except ModuleNotFoundError:
|
||||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
@@ -42,7 +42,7 @@ def main():
|
||||
PYTORCH_DUMP_OUTPUT)
|
||||
else:
|
||||
try:
|
||||
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
|
||||
import tensorflow as tf
|
||||
except ModuleNotFoundError:
|
||||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
|
||||
@@ -18,13 +18,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import json
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
||||
from pytorch_pretrained_bert.modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||
# Construct model
|
||||
@@ -67,5 +64,5 @@ if __name__ == "__main__":
|
||||
"This specifies the model architecture.")
|
||||
args = parser.parse_args()
|
||||
convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.openai_config_file)
|
||||
args.openai_config_file,
|
||||
args.pytorch_dump_folder_path)
|
||||
|
||||
@@ -25,7 +25,7 @@ import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
# Initialise PyTorch model
|
||||
|
||||
@@ -52,6 +52,14 @@ TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ModuleNotFoundError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
|
||||
@@ -15,23 +15,23 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch OpenAI GPT model."""
|
||||
|
||||
import os
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import collections
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .file_utils import cached_path
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,6 +42,8 @@ WEIGHTS_NAME = "pytorch_model.bin"
|
||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||
"""
|
||||
import re
|
||||
import numpy as np
|
||||
print("Loading weights...")
|
||||
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
|
||||
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
|
||||
@@ -50,18 +52,24 @@ def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
|
||||
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||
del init_params[1]
|
||||
# Thsi as used when we had a single embedding matrix for positions and tokens
|
||||
# init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||
# del init_params[1]
|
||||
init_params = [arr.squeeze() for arr in init_params]
|
||||
|
||||
try:
|
||||
assert model.embed.weight.shape == init_params[0].shape
|
||||
assert model.tokens_embed.weight.shape == init_params[1].shape
|
||||
assert model.positions_embed.weight.shape == init_params[0].shape
|
||||
except AssertionError as e:
|
||||
e.args += (model.embed.weight.shape, init_params[0].shape)
|
||||
e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
|
||||
e.args += (model.positions_embed.weight.shape, init_params[0].shape)
|
||||
raise
|
||||
|
||||
model.embed.weight.data = torch.from_numpy(init_params[0])
|
||||
model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
|
||||
model.positions_embed.weight.data = torch.from_numpy(init_params[0])
|
||||
names.pop(0)
|
||||
# Pop position and token embedding arrays
|
||||
init_params.pop(0)
|
||||
init_params.pop(0)
|
||||
|
||||
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
|
||||
@@ -584,8 +592,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTModel, self).__init__(config)
|
||||
total_embeddings_size = config.vocab_size + config.n_special + config.n_positions
|
||||
self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
|
||||
num_tokens = config.vocab_size + config.n_special
|
||||
self.tokens_embed = nn.Embedding(num_tokens, config.n_embd)
|
||||
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
block = Block(config.n_ctx, config, scale=True)
|
||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
|
||||
@@ -598,30 +607,32 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
# Update config
|
||||
self.config.n_special = num_special_tokens
|
||||
# # Build new embeddings and initialize
|
||||
old_embed = self.embed
|
||||
self.embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
|
||||
old_embed = self.tokens_embed
|
||||
self.tokens_embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
|
||||
# Initialize all new embeddings (in particular the special tokens)
|
||||
self.init_weights(self.embed)
|
||||
self.init_weights(self.tokens_embed)
|
||||
# Copy word and positional embeddings from the previous weights
|
||||
self.embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
|
||||
self.embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
|
||||
self.tokens_embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
|
||||
self.tokens_embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
||||
if position_ids is None:
|
||||
start = self.config.vocab_size + self.config.n_special
|
||||
end = start + input_ids.size(-1)
|
||||
position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
|
||||
# This was used when we had a single embedding matrice from position and token embeddings
|
||||
# start = self.config.vocab_size + self.config.n_special
|
||||
# end = start + input_ids.size(-1)
|
||||
# position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1))
|
||||
|
||||
inputs_embeds = self.embed(input_ids)
|
||||
position_embeds = self.embed(position_ids)
|
||||
inputs_embeds = self.tokens_embed(input_ids)
|
||||
position_embeds = self.positions_embed(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
||||
token_type_embeds = self.embed(token_type_ids)
|
||||
token_type_embeds = self.tokens_embed(token_type_ids)
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
# Add the position information to the input embeddings
|
||||
@@ -694,13 +705,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTLMHeadModel, self).__init__(config)
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def set_num_special_tokens(self, num_special_tokens):
|
||||
" Update input and output embeddings with new embedding matrice "
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||
@@ -780,14 +791,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
|
||||
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def set_num_special_tokens(self, num_special_tokens):
|
||||
" Update input and output embeddings with new embedding matrice "
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
||||
|
||||
def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||
|
||||
@@ -121,6 +121,13 @@ def build_tf_to_pytorch_map(model, config):
|
||||
def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
try:
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ModuleNotFoundError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
# Build TF to PyTorch weights loading map
|
||||
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
n_special=1,
|
||||
n_ctx=33,
|
||||
n_positions=33,
|
||||
n_embd=32,
|
||||
n_layer=5,
|
||||
n_head=4,
|
||||
@@ -61,7 +61,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.n_special = n_special
|
||||
self.n_ctx = n_ctx
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
@@ -80,12 +80,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
|
||||
position_ids = None
|
||||
if self.use_position_ids:
|
||||
position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_ctx)
|
||||
position_ids = position_ids + self.n_special + self.vocab_size
|
||||
position_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
total_voc = self.n_ctx + self.n_special + self.vocab_size
|
||||
total_voc = self.vocab_size + self.n_special
|
||||
token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
|
||||
|
||||
mc_labels = None
|
||||
@@ -98,7 +97,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
|
||||
config = OpenAIGPTConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
n_ctx=self.n_ctx,
|
||||
n_positions=self.n_positions,
|
||||
n_special=self.n_special,
|
||||
n_embd=self.n_embd,
|
||||
n_layer=self.n_layer,
|
||||
@@ -139,7 +138,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
return outputs
|
||||
|
||||
def check_openai_lm_head_output(self, result):
|
||||
total_voc = self.n_ctx + self.n_special + self.vocab_size
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
@@ -164,7 +163,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
||||
return outputs
|
||||
|
||||
def check_openai_double_heads_output(self, result):
|
||||
total_voc = self.n_ctx + self.n_special + self.vocab_size
|
||||
total_voc = self.n_special + self.vocab_size
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()),
|
||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||
|
||||
Reference in New Issue
Block a user