Add predict_special_tokens option to GPT as well
This commit is contained in:
@@ -143,6 +143,7 @@ class OpenAIGPTConfig(object):
|
|||||||
attn_pdrop=0.1,
|
attn_pdrop=0.1,
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
|
predict_special_tokens=True
|
||||||
):
|
):
|
||||||
"""Constructs OpenAIGPTConfig.
|
"""Constructs OpenAIGPTConfig.
|
||||||
|
|
||||||
@@ -165,6 +166,7 @@ class OpenAIGPTConfig(object):
|
|||||||
layer_norm_epsilon: epsilon to use in the layer norm layers
|
layer_norm_epsilon: epsilon to use in the layer norm layers
|
||||||
initializer_range: The stddev of the truncated_normal_initializer for
|
initializer_range: The stddev of the truncated_normal_initializer for
|
||||||
initializing all weight matrices.
|
initializing all weight matrices.
|
||||||
|
predict_special_tokens: whether the LM head predicts special tokens; when False, the LM logits are truncated to the first vocab_size entries (only relevant when the model has an LM head)
|
||||||
"""
|
"""
|
||||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||||
and isinstance(vocab_size_or_config_json_file, unicode)):
|
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||||
@@ -186,6 +188,7 @@ class OpenAIGPTConfig(object):
|
|||||||
self.attn_pdrop = attn_pdrop
|
self.attn_pdrop = attn_pdrop
|
||||||
self.layer_norm_epsilon = layer_norm_epsilon
|
self.layer_norm_epsilon = layer_norm_epsilon
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
self.predict_special_tokens = predict_special_tokens
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"First argument must be either a vocabulary size (int)"
|
"First argument must be either a vocabulary size (int)"
|
||||||
@@ -356,18 +359,21 @@ class OpenAIGPTLMHead(nn.Module):
|
|||||||
def __init__(self, model_embeddings_weights, config):
|
def __init__(self, model_embeddings_weights, config):
|
||||||
super(OpenAIGPTLMHead, self).__init__()
|
super(OpenAIGPTLMHead, self).__init__()
|
||||||
self.n_embd = config.n_embd
|
self.n_embd = config.n_embd
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.predict_special_tokens = config.predict_special_tokens
|
||||||
embed_shape = model_embeddings_weights.shape
|
embed_shape = model_embeddings_weights.shape
|
||||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||||
self.set_embeddings_weights(model_embeddings_weights)
|
self.set_embeddings_weights(model_embeddings_weights)
|
||||||
|
|
||||||
def set_embeddings_weights(self, model_embeddings_weights):
|
def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
|
||||||
|
self.predict_special_tokens = predict_special_tokens
|
||||||
embed_shape = model_embeddings_weights.shape
|
embed_shape = model_embeddings_weights.shape
|
||||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||||
|
|
||||||
def forward(self, hidden_state):
|
def forward(self, hidden_state):
|
||||||
# Truncated Language modeling logits (we remove the last token)
|
|
||||||
# h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
|
|
||||||
lm_logits = self.decoder(hidden_state)
|
lm_logits = self.decoder(hidden_state)
|
||||||
|
if not self.predict_special_tokens:
|
||||||
|
lm_logits = lm_logits[..., :self.vocab_size]
|
||||||
return lm_logits
|
return lm_logits
|
||||||
|
|
||||||
|
|
||||||
@@ -428,9 +434,6 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
def set_num_special_tokens(self, num_special_tokens):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
||||||
@@ -613,7 +616,6 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
|
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
|
||||||
|
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
# nn.init.normal_(self.embed.weight, std=0.02)
|
|
||||||
|
|
||||||
def set_num_special_tokens(self, num_special_tokens):
|
def set_num_special_tokens(self, num_special_tokens):
|
||||||
" Update input embeddings with new embedding matrix if needed "
|
" Update input embeddings with new embedding matrix if needed "
|
||||||
@@ -727,12 +729,13 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
|
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def set_num_special_tokens(self, num_special_tokens):
|
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
|
||||||
""" Update input and output embeddings with new embedding matrix
|
""" Update input and output embeddings with new embedding matrix
|
||||||
Make sure we are sharing the embeddings
|
Make sure we are sharing the embeddings
|
||||||
"""
|
"""
|
||||||
|
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
|
||||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
|
||||||
|
|
||||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
|
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
|
||||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||||
@@ -821,12 +824,13 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
|
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
|
||||||
self.apply(self.init_weights)
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
def set_num_special_tokens(self, num_special_tokens):
|
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
|
||||||
""" Update input and output embeddings with new embedding matrix
|
""" Update input and output embeddings with new embedding matrix
|
||||||
Make sure we are sharing the embeddings
|
Make sure we are sharing the embeddings
|
||||||
"""
|
"""
|
||||||
|
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
|
||||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight)
|
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
|
||||||
|
|
||||||
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
|
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
|
||||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user