add special tokens to gpt-2
This commit is contained in:
@@ -107,6 +107,7 @@ class GPT2Config(object):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size_or_config_json_file=50257,
|
vocab_size_or_config_json_file=50257,
|
||||||
|
n_special=0,
|
||||||
n_positions=1024,
|
n_positions=1024,
|
||||||
n_ctx=1024,
|
n_ctx=1024,
|
||||||
n_embd=768,
|
n_embd=768,
|
||||||
@@ -119,6 +120,7 @@ class GPT2Config(object):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
|
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
|
||||||
|
n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
|
||||||
n_positions: Number of positional embeddings.
|
n_positions: Number of positional embeddings.
|
||||||
n_ctx: Size of the causal mask (usually same as n_positions).
|
n_ctx: Size of the causal mask (usually same as n_positions).
|
||||||
n_embd: Dimensionality of the embeddings and hidden states.
|
n_embd: Dimensionality of the embeddings and hidden states.
|
||||||
@@ -137,6 +139,7 @@ class GPT2Config(object):
|
|||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
elif isinstance(vocab_size_or_config_json_file, int):
|
elif isinstance(vocab_size_or_config_json_file, int):
|
||||||
self.vocab_size = vocab_size_or_config_json_file
|
self.vocab_size = vocab_size_or_config_json_file
|
||||||
|
self.n_special = n_special
|
||||||
self.n_ctx = n_ctx
|
self.n_ctx = n_ctx
|
||||||
self.n_positions = n_positions
|
self.n_positions = n_positions
|
||||||
self.n_embd = n_embd
|
self.n_embd = n_embd
|
||||||
@@ -150,6 +153,10 @@ class GPT2Config(object):
|
|||||||
"or the path to a pretrained model config file (str)"
|
"or the path to a pretrained model config file (str)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens_embeddings(self):
|
||||||
|
return self.vocab_size + self.n_special
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, json_object):
|
def from_dict(cls, json_object):
|
||||||
"""Constructs a `GPT2Config` from a Python dictionary of parameters."""
|
"""Constructs a `GPT2Config` from a Python dictionary of parameters."""
|
||||||
@@ -290,11 +297,12 @@ class GPT2LMHead(nn.Module):
|
|||||||
def __init__(self, model_embeddings_weights, config):
|
def __init__(self, model_embeddings_weights, config):
|
||||||
super(GPT2LMHead, self).__init__()
|
super(GPT2LMHead, self).__init__()
|
||||||
self.n_embd = config.n_embd
|
self.n_embd = config.n_embd
|
||||||
|
embed_shape = model_embeddings_weights.shape
|
||||||
|
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||||
self.set_embeddings_weights(model_embeddings_weights)
|
self.set_embeddings_weights(model_embeddings_weights)
|
||||||
|
|
||||||
def set_embeddings_weights(self, model_embeddings_weights):
|
def set_embeddings_weights(self, model_embeddings_weights):
|
||||||
embed_shape = model_embeddings_weights.shape
|
embed_shape = model_embeddings_weights.shape
|
||||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
|
||||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||||
|
|
||||||
def forward(self, hidden_state):
|
def forward(self, hidden_state):
|
||||||
@@ -345,7 +353,7 @@ class GPT2PreTrainedModel(nn.Module):
|
|||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def set_tied(self):
|
def set_num_special_tokens(self, num_special_tokens):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def init_weights(self, module):
|
def init_weights(self, module):
|
||||||
@@ -475,14 +483,32 @@ class GPT2PreTrainedModel(nn.Module):
|
|||||||
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure we are still sharing the output and input embeddings after loading weights
|
# Add additional embeddings for special tokens if needed
|
||||||
model.set_tied()
|
# This step also make sure we are still sharing the output and input embeddings after loading weights
|
||||||
|
model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class GPT2Model(GPT2PreTrainedModel):
|
class GPT2Model(GPT2PreTrainedModel):
|
||||||
"""OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
|
"""OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
|
||||||
|
|
||||||
|
GPT-2 use a single embedding matrix to store the word and special embeddings.
|
||||||
|
Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
|
||||||
|
Special tokens need to be trained during the fine-tuning if you use them.
|
||||||
|
The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function.
|
||||||
|
|
||||||
|
The embeddings are ordered as follow in the token embeddings matrice:
|
||||||
|
[0, ----------------------
|
||||||
|
... -> word embeddings
|
||||||
|
config.vocab_size - 1, ______________________
|
||||||
|
config.vocab_size,
|
||||||
|
... -> special embeddings
|
||||||
|
config.vocab_size + config.n_special - 1] ______________________
|
||||||
|
|
||||||
|
where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
|
||||||
|
total_tokens_embeddings = config.vocab_size + config.n_special
|
||||||
|
You should use the associate indices to index the embeddings.
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
config: a GPT2Config class instance with the configuration to build a new model
|
config: a GPT2Config class instance with the configuration to build a new model
|
||||||
|
|
||||||
@@ -529,6 +555,20 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
|
def set_num_special_tokens(self, num_special_tokens):
|
||||||
|
" Update input embeddings with new embedding matrice if needed "
|
||||||
|
if self.config.n_special == num_special_tokens:
|
||||||
|
return
|
||||||
|
# Update config
|
||||||
|
self.config.n_special = num_special_tokens
|
||||||
|
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
|
||||||
|
old_embed = self.wte
|
||||||
|
self.wte = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
|
||||||
|
self.wte.to(old_embed.weight.device)
|
||||||
|
self.init_weights(self.wte)
|
||||||
|
# Copy word embeddings from the previous weights
|
||||||
|
self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
|
||||||
if past is None:
|
if past is None:
|
||||||
past_length = 0
|
past_length = 0
|
||||||
@@ -610,9 +650,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def set_tied(self):
|
def set_num_special_tokens(self, num_special_tokens):
|
||||||
""" Make sure we are sharing the embeddings
|
""" Update input and output embeddings with new embedding matrice
|
||||||
|
Make sure we are sharing the embeddings
|
||||||
"""
|
"""
|
||||||
|
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||||
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
|
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
|
||||||
@@ -687,9 +729,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
|
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def set_tied(self):
|
def set_num_special_tokens(self, num_special_tokens):
|
||||||
""" Make sure we are sharing the embeddings
|
""" Update input and output embeddings with new embedding matrice
|
||||||
|
Make sure we are sharing the embeddings
|
||||||
"""
|
"""
|
||||||
|
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||||
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
|
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
|
||||||
|
|
||||||
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
|
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
|
||||||
|
|||||||
@@ -344,11 +344,12 @@ class OpenAIGPTLMHead(nn.Module):
|
|||||||
def __init__(self, model_embeddings_weights, config):
|
def __init__(self, model_embeddings_weights, config):
|
||||||
super(OpenAIGPTLMHead, self).__init__()
|
super(OpenAIGPTLMHead, self).__init__()
|
||||||
self.n_embd = config.n_embd
|
self.n_embd = config.n_embd
|
||||||
|
embed_shape = model_embeddings_weights.shape
|
||||||
|
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||||
self.set_embeddings_weights(model_embeddings_weights)
|
self.set_embeddings_weights(model_embeddings_weights)
|
||||||
|
|
||||||
def set_embeddings_weights(self, model_embeddings_weights):
|
def set_embeddings_weights(self, model_embeddings_weights):
|
||||||
embed_shape = model_embeddings_weights.shape
|
embed_shape = model_embeddings_weights.shape
|
||||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
|
||||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||||
|
|
||||||
def forward(self, hidden_state):
|
def forward(self, hidden_state):
|
||||||
@@ -592,8 +593,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(OpenAIGPTModel, self).__init__(config)
|
super(OpenAIGPTModel, self).__init__(config)
|
||||||
num_tokens = config.vocab_size + config.n_special
|
self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
|
||||||
self.tokens_embed = nn.Embedding(num_tokens, config.n_embd)
|
|
||||||
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
|
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
|
||||||
self.drop = nn.Dropout(config.embd_pdrop)
|
self.drop = nn.Dropout(config.embd_pdrop)
|
||||||
block = Block(config.n_ctx, config, scale=True)
|
block = Block(config.n_ctx, config, scale=True)
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
use_token_type_ids=True,
|
use_token_type_ids=True,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
|
n_special=1,
|
||||||
n_positions=33,
|
n_positions=33,
|
||||||
n_embd=32,
|
n_embd=32,
|
||||||
n_layer=5,
|
n_layer=5,
|
||||||
@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
self.use_token_type_ids = use_token_type_ids
|
self.use_token_type_ids = use_token_type_ids
|
||||||
self.use_labels = use_labels
|
self.use_labels = use_labels
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
self.n_special = n_special
|
||||||
self.n_positions = n_positions
|
self.n_positions = n_positions
|
||||||
self.n_embd = n_embd
|
self.n_embd = n_embd
|
||||||
self.n_layer = n_layer
|
self.n_layer = n_layer
|
||||||
@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)
|
total_num_tokens = self.vocab_size + self.n_special
|
||||||
|
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
|
||||||
|
|
||||||
position_ids = None
|
position_ids = None
|
||||||
if self.use_position_ids:
|
if self.use_position_ids:
|
||||||
@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
|
|
||||||
config = GPT2Config(
|
config = GPT2Config(
|
||||||
vocab_size_or_config_json_file=self.vocab_size,
|
vocab_size_or_config_json_file=self.vocab_size,
|
||||||
|
n_special=self.n_special,
|
||||||
n_positions=self.n_positions,
|
n_positions=self.n_positions,
|
||||||
n_embd=self.n_embd,
|
n_embd=self.n_embd,
|
||||||
n_layer=self.n_layer,
|
n_layer=self.n_layer,
|
||||||
@@ -130,7 +134,7 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def check_gpt2_lm_head_output(self, result):
|
def check_gpt2_lm_head_output(self, result):
|
||||||
total_voc = self.vocab_size
|
total_voc = self.n_special + self.vocab_size
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["lm_logits"].size()),
|
list(result["lm_logits"].size()),
|
||||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||||
@@ -157,7 +161,7 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def check_gpt2_double_heads_output(self, result):
|
def check_gpt2_double_heads_output(self, result):
|
||||||
total_voc = self.vocab_size
|
total_voc = self.n_special + self.vocab_size
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["lm_logits"].size()),
|
list(result["lm_logits"].size()),
|
||||||
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
[self.batch_size, self.n_choices, self.seq_length, total_voc])
|
||||||
|
|||||||
Reference in New Issue
Block a user