embeddings resizing + tie_weights
This commit is contained in:
@@ -156,7 +156,6 @@ class OpenAIGPTConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size_or_config_json_file=40478,
|
||||
n_special=0,
|
||||
n_positions=512,
|
||||
n_ctx=512,
|
||||
n_embd=768,
|
||||
@@ -190,7 +189,6 @@ class OpenAIGPTConfig(PretrainedConfig):
|
||||
self.__dict__[key] = value
|
||||
elif isinstance(vocab_size_or_config_json_file, int):
|
||||
self.vocab_size = vocab_size_or_config_json_file
|
||||
self.n_special = n_special
|
||||
self.n_ctx = n_ctx
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
@@ -216,10 +214,6 @@ class OpenAIGPTConfig(PretrainedConfig):
|
||||
"or the path to a pretrained model config file (str)"
|
||||
)
|
||||
|
||||
@property
|
||||
def total_tokens_embeddings(self):
|
||||
return self.vocab_size + self.n_special
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
return self.n_embd
|
||||
@@ -355,34 +349,6 @@ class Block(nn.Module):
|
||||
return outputs
|
||||
|
||||
|
||||
class OpenAIGPTLMHead(nn.Module):
|
||||
""" Language Model Head for the transformer """
|
||||
|
||||
def __init__(self, model_embeddings_weights, config):
|
||||
super(OpenAIGPTLMHead, self).__init__()
|
||||
self.n_embd = config.n_embd
|
||||
self.vocab_size = config.vocab_size
|
||||
self.predict_special_tokens = config.predict_special_tokens
|
||||
self.torchscript = config.torchscript
|
||||
embed_shape = model_embeddings_weights.shape
|
||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||
self.set_embeddings_weights(model_embeddings_weights)
|
||||
|
||||
def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
|
||||
self.predict_special_tokens = predict_special_tokens
|
||||
|
||||
if self.torchscript:
|
||||
self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
|
||||
else:
|
||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||
|
||||
def forward(self, hidden_state):
|
||||
lm_logits = self.decoder(hidden_state)
|
||||
if not self.predict_special_tokens:
|
||||
lm_logits = lm_logits[..., :self.vocab_size]
|
||||
return lm_logits
|
||||
|
||||
|
||||
class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
@@ -408,36 +374,6 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate an OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path: either:
|
||||
- a str with the name of a pre-trained model to load selected in the list of:
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `config.json` a configuration file for the model
|
||||
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `config.json` a configuration file for the model
|
||||
. a series of NumPy files containing OpenAI TensorFlow trained weights
|
||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
|
||||
*inputs, **kwargs: additional input for the specific OpenAI-GPT class
|
||||
"""
|
||||
num_special_tokens = kwargs.get('num_special_tokens', None)
|
||||
kwargs.pop('num_special_tokens', None)
|
||||
|
||||
model = super(PreTrainedModel, cls).from_pretrained(pretrained_model_name_or_path, pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
# Add additional embeddings for special tokens if needed
|
||||
# This step also makes sure we are still sharing the output and input embeddings after loading weights
|
||||
model.set_num_special_tokens(num_special_tokens)
|
||||
return model
|
||||
|
||||
|
||||
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
"""OpenAI GPT model ("Improving Language Understanding by Generative Pre-Training").
|
||||
@@ -457,13 +393,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
config.vocab_size - 1, ______________________
|
||||
config.vocab_size,
|
||||
... -> special embeddings
|
||||
config.vocab_size + config.n_special - 1] ______________________
|
||||
config.vocab_size + n_special - 1] ______________________
|
||||
|
||||
where ``total_tokens_embeddings`` can be obtained as ``config.total_tokens_embeddings`` and is:
|
||||
where ``total_tokens_embeddings`` is:
|
||||
|
||||
::
|
||||
|
||||
total_tokens_embeddings = config.vocab_size + config.n_special
|
||||
total_tokens_embeddings = config.vocab_size + n_special
|
||||
|
||||
You should use the associated indices to index the embeddings.
|
||||
|
||||
@@ -485,34 +421,15 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
|
||||
self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def set_num_special_tokens(self, num_special_tokens=None):
|
||||
"""
|
||||
Update input embeddings with new embedding matrix if needed
|
||||
|
||||
Args:
|
||||
num_special_tokens: Special tokens to be added to the embedding matrix
|
||||
|
||||
TODO Lysandre filled Args
|
||||
|
||||
"""
|
||||
if num_special_tokens is None or self.config.n_special == num_special_tokens:
|
||||
return
|
||||
# Update config
|
||||
self.config.n_special = num_special_tokens
|
||||
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
|
||||
old_embed = self.tokens_embed
|
||||
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
|
||||
self.tokens_embed.to(old_embed.weight.device)
|
||||
self.init_weights(self.tokens_embed)
|
||||
# Copy word embeddings from the previous weights
|
||||
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
@@ -657,24 +574,20 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTLMHeadModel, self).__init__(config)
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.tie_weights()
|
||||
|
||||
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
Update input and output embeddings with new embedding matrix. Make sure we are sharing the embeddings
|
||||
|
||||
Args:
|
||||
num_special_tokens: Special tokens to be added to the embedding matrix
|
||||
predict_special_tokens: if set to True, the model will try and predict the specified ``num_special_tokens``.
|
||||
Defaults to True.
|
||||
|
||||
TODO Lysandre filled Args
|
||||
|
||||
"""
|
||||
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
|
||||
input_embeddings = self.transformer.tokens_embed.weight
|
||||
if self.config.torchscript:
|
||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.lm_head.weight = input_embeddings # Tied weights
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
|
||||
"""
|
||||
@@ -747,13 +660,13 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
config.vocab_size - 1, ______________________
|
||||
config.vocab_size,
|
||||
... -> special embeddings
|
||||
config.vocab_size + config.n_special - 1] ______________________
|
||||
config.vocab_size + n_special - 1] ______________________
|
||||
|
||||
where ``total_tokens_embeddings`` can be obtained as ``config.total_tokens_embeddings`` and is:
|
||||
where ``total_tokens_embeddings`` is:
|
||||
|
||||
::
|
||||
|
||||
total_tokens_embeddings = config.vocab_size + config.n_special
|
||||
total_tokens_embeddings = config.vocab_size + n_special
|
||||
|
||||
You should use the associated indices to index the embeddings.
|
||||
|
||||
@@ -773,24 +686,21 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
|
||||
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.multiple_choice_head = SequenceSummary(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.tie_weights()
|
||||
|
||||
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
|
||||
""" Update input and output embeddings with new embedding matrix. Make sure we are sharing the embeddings.
|
||||
|
||||
Args:
|
||||
num_special_tokens: Special tokens to be added to the embedding matrix
|
||||
predict_special_tokens: if set to True, the model will try and predict the specified ``num_special_tokens``.
|
||||
Defaults to True.
|
||||
|
||||
TODO Lysandre filled Args
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
|
||||
input_embeddings = self.transformer.tokens_embed.weight
|
||||
if self.config.torchscript:
|
||||
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.lm_head.weight = input_embeddings # Tied weights
|
||||
|
||||
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
|
||||
position_ids=None, head_mask=None):
|
||||
|
||||
Reference in New Issue
Block a user