splitting position and tokens embeddings in OpenAI GPT - updating tf imports - tests
This commit is contained in:
@@ -15,23 +15,23 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch OpenAI GPT model."""
|
||||
|
||||
import os
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import collections
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .file_utils import cached_path
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,6 +42,8 @@ WEIGHTS_NAME = "pytorch_model.bin"
|
||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||
"""
|
||||
import re
|
||||
import numpy as np
|
||||
print("Loading weights...")
|
||||
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
|
||||
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
|
||||
@@ -50,18 +52,24 @@ def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
|
||||
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||
del init_params[1]
|
||||
# This was used when we had a single embedding matrix for positions and tokens
|
||||
# init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||
# del init_params[1]
|
||||
init_params = [arr.squeeze() for arr in init_params]
|
||||
|
||||
try:
|
||||
assert model.embed.weight.shape == init_params[0].shape
|
||||
assert model.tokens_embed.weight.shape == init_params[1].shape
|
||||
assert model.positions_embed.weight.shape == init_params[0].shape
|
||||
except AssertionError as e:
|
||||
e.args += (model.embed.weight.shape, init_params[0].shape)
|
||||
e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
|
||||
e.args += (model.positions_embed.weight.shape, init_params[0].shape)
|
||||
raise
|
||||
|
||||
model.embed.weight.data = torch.from_numpy(init_params[0])
|
||||
model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
|
||||
model.positions_embed.weight.data = torch.from_numpy(init_params[0])
|
||||
names.pop(0)
|
||||
# Pop position and token embedding arrays
|
||||
init_params.pop(0)
|
||||
init_params.pop(0)
|
||||
|
||||
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
|
||||
@@ -584,8 +592,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTModel, self).__init__(config)
|
||||
total_embeddings_size = config.vocab_size + config.n_special + config.n_positions
|
||||
self.embed = nn.Embedding(total_embeddings_size, config.n_embd)
|
||||
num_tokens = config.vocab_size + config.n_special
|
||||
self.tokens_embed = nn.Embedding(num_tokens, config.n_embd)
|
||||
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
block = Block(config.n_ctx, config, scale=True)
|
||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
|
||||
@@ -598,30 +607,32 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
# Update config
|
||||
self.config.n_special = num_special_tokens
|
||||
# # Build new embeddings and initialize
|
||||
old_embed = self.embed
|
||||
self.embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
|
||||
old_embed = self.tokens_embed
|
||||
self.tokens_embed = nn.Embedding(self.config.total_num_embeddings, self.config.n_embd)
|
||||
# Initialize all new embeddings (in particular the special tokens)
|
||||
self.init_weights(self.embed)
|
||||
self.init_weights(self.tokens_embed)
|
||||
# Copy word and positional embeddings from the previous weights
|
||||
self.embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
|
||||
self.embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
|
||||
self.tokens_embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :]
|
||||
self.tokens_embed.weight.data[-self.config.n_positions :, :] = old_embed.weight.data[-self.config.n_positions :, :]
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
||||
if position_ids is None:
|
||||
start = self.config.vocab_size + self.config.n_special
|
||||
end = start + input_ids.size(-1)
|
||||
position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
|
||||
# This was used when we had a single embedding matrix for position and token embeddings
|
||||
# start = self.config.vocab_size + self.config.n_special
|
||||
# end = start + input_ids.size(-1)
|
||||
# position_ids = torch.arange(start, end, dtype=torch.long, device=input_ids.device)
|
||||
position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1))
|
||||
|
||||
inputs_embeds = self.embed(input_ids)
|
||||
position_embeds = self.embed(position_ids)
|
||||
inputs_embeds = self.tokens_embed(input_ids)
|
||||
position_embeds = self.positions_embed(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
||||
token_type_embeds = self.embed(token_type_ids)
|
||||
token_type_embeds = self.tokens_embed(token_type_ids)
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
# Add the position information to the input embeddings
|
||||
@@ -694,13 +705,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTLMHeadModel, self).__init__(config)
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def set_num_special_tokens(self, num_special_tokens):
|
||||
" Update input and output embeddings with new embedding matrix "
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||
@@ -780,14 +791,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.embed.weight, config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
|
||||
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def set_num_special_tokens(self, num_special_tokens):
|
||||
" Update input and output embeddings with new embedding matrix "
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
||||
|
||||
def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||
|
||||
Reference in New Issue
Block a user