embeddings resizing + tie_weights

This commit is contained in:
thomwolf
2019-07-12 00:02:49 +02:00
parent 50e62a4cb4
commit bd404735a7
15 changed files with 196 additions and 332 deletions

View File

@@ -507,23 +507,17 @@ class BertPredictionHeadTransform(nn.Module):
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
def __init__(self, config):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
self.torchscript = config.torchscript
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
bert_model_embedding_weights.size(0),
self.decoder = nn.Linear(config.hidden_size,
config.vocab_size,
bias=False)
if self.torchscript:
self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
else:
self.decoder.weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
@@ -532,9 +526,9 @@ class BertLMPredictionHead(nn.Module):
class BertOnlyMLMHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
def __init__(self, config):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
@@ -552,9 +546,9 @@ class BertOnlyNSPHead(nn.Module):
class BertPreTrainingHeads(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
def __init__(self, config):
super(BertPreTrainingHeads, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
self.predictions = BertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
@@ -619,6 +613,11 @@ class BertModel(BertPreTrainedModel):
self.apply(self.init_weights)
def _resize_token_embeddings(self, new_num_tokens):
    """ Swap the word-embedding module for a resized one.

    The current ``self.embeddings.word_embeddings`` module and ``new_num_tokens``
    are handed to ``self._get_resized_embeddings``; whatever that helper returns
    is stored back as ``self.embeddings.word_embeddings``.
    """
    embeddings_holder = self.embeddings
    # The helper decides how the resized module is built; this method only rebinds it.
    embeddings_holder.word_embeddings = self._get_resized_embeddings(
        embeddings_holder.word_embeddings, new_num_tokens
    )
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
@@ -750,9 +749,20 @@ class BertForPreTraining(BertPreTrainedModel):
super(BertForPreTraining, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.cls = BertPreTrainingHeads(config)
self.apply(self.init_weights)
self.tie_weights()
def tie_weights(self):
    """ Share the input word-embedding matrix with the prediction decoder.

    The weight of ``self.bert.embeddings.word_embeddings`` is assigned as the
    weight of ``self.cls.predictions.decoder``. When ``self.config.torchscript``
    is set, a cloned copy wrapped in a fresh ``nn.Parameter`` is assigned instead,
    because export to TorchScript can't handle parameter sharing.
    """
    embedding_weight = self.bert.embeddings.word_embeddings.weight
    # TorchScript mode gets an independent clone; otherwise the very same
    # Parameter object is reused (tied weights).
    decoder_weight = (
        nn.Parameter(embedding_weight.clone())
        if self.config.torchscript
        else embedding_weight
    )
    self.cls.predictions.decoder.weight = decoder_weight
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
next_sentence_label=None, head_mask=None):
@@ -845,9 +855,20 @@ class BertForMaskedLM(BertPreTrainedModel):
super(BertForMaskedLM, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.cls = BertOnlyMLMHead(config)
self.apply(self.init_weights)
self.tie_weights()
def tie_weights(self):
    """ Point the masked-LM decoder at the input word embeddings.

    In the regular case the decoder weight becomes the same Parameter object as
    ``self.bert.embeddings.word_embeddings.weight``. When
    ``self.config.torchscript`` is set, the decoder receives a cloned copy in a
    new ``nn.Parameter``, since export to TorchScript can't handle parameter sharing.
    """
    shared_weight = self.bert.embeddings.word_embeddings.weight
    if not self.config.torchscript:
        # Tied weights: both modules hold the identical Parameter.
        self.cls.predictions.decoder.weight = shared_weight
        return
    # TorchScript export path: same values, separate storage.
    self.cls.predictions.decoder.weight = nn.Parameter(shared_weight.clone())
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
"""