splitting position and tokens embeddings in OpenAI GPT - updating tf imports - tests

This commit is contained in:
thomwolf
2019-01-29 10:31:42 +01:00
parent 5456d82311
commit 98c96fb1a7
7 changed files with 66 additions and 44 deletions

View File

@@ -15,23 +15,23 @@
# limitations under the License.
"""PyTorch OpenAI GPT model."""
import os
import collections
import copy
import json
import math
import logging
import math
import os
import shutil
import tarfile
import tempfile
import shutil
import collections
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .file_utils import cached_path
from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
@@ -42,6 +42,8 @@ WEIGHTS_NAME = "pytorch_model.bin"
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
"""
import re
import numpy as np
print("Loading weights...")
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
@@ -50,18 +52,24 @@ def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
del init_params[1]
# This was used when we had a single embedding matrix for positions and tokens
# init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
# del init_params[1]
init_params = [arr.squeeze() for arr in init_params]
try:
assert model.embed.weight.shape == init_params[0].shape
assert model.tokens_embed.weight.shape == init_params[1].shape
assert model.positions_embed.weight.shape == init_params[0].shape
except AssertionError as e:
e.args += (model.embed.weight.shape, init_params[0].shape)
e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
e.args += (model.positions_embed.weight.shape, init_params[0].shape)
raise
model.embed.weight.data = torch.from_numpy(init_params[0])
model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
model.positions_embed.weight.data = torch.from_numpy(init_params[0])
names.pop(0)
# Pop position and token embedding arrays
init_params.pop(0)
init_params.pop(0)
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
@@ -584,8 +592,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super(OpenAIGPTModel, self).__init__(config)
total_embeddings_size = config.vocab_size + config.n_special + config.n_positions
self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
num_tokens = config.vocab_size + config.n_special
self.tokens_embed = nn.Embedding(num_tokens, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
@@ -598,30 +607,32 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Update config
self.config.n_special = num_special_tokens
# # Build new embeddings and initialize
old_embed = self.embed
self.embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
old_embed = self.tokens_embed
self.tokens_embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
# Initialize all new embeddings (in particular the special tokens)
self.init_weights(self.embed)
self.init_weights(self.tokens_embed)
# Copy word and positional embeddings from the previous weights
self.embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
self.embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
self.tokens_embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
self.tokens_embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
def forward(self, input_ids, position_ids=None, token_type_ids=None):
if position_ids is None:
start = self.config.vocab_size + self.config.n_special
end = start + input_ids.size(-1)
position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
# This was used when we had a single embedding matrix for position and token embeddings
# start = self.config.vocab_size + self.config.n_special
# end = start + input_ids.size(-1)
# position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
position_ids = position_ids.view(-1, position_ids.size(-1))
inputs_embeds = self.embed(input_ids)
position_embeds = self.embed(position_ids)
inputs_embeds = self.tokens_embed(input_ids)
position_embeds = self.positions_embed(position_ids)
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
token_type_embeds = self.embed(token_type_ids)
token_type_embeds = self.tokens_embed(token_type_ids)
else:
token_type_embeds = 0
# Add the position information to the input embeddings
@@ -694,13 +705,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config)
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens):
" Update input and output embeddings with new embedding matrix "
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
@@ -780,14 +791,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config)
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens):
" Update input and output embeddings with new embedding matrix "
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)