embeddings resizing + tie_weights
This commit is contained in:
@@ -507,23 +507,17 @@ class BertPredictionHeadTransform(nn.Module):
|
||||
|
||||
|
||||
class BertLMPredictionHead(nn.Module):
|
||||
def __init__(self, config, bert_model_embedding_weights):
|
||||
def __init__(self, config):
|
||||
super(BertLMPredictionHead, self).__init__()
|
||||
self.transform = BertPredictionHeadTransform(config)
|
||||
self.torchscript = config.torchscript
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
|
||||
bert_model_embedding_weights.size(0),
|
||||
self.decoder = nn.Linear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False)
|
||||
|
||||
if self.torchscript:
|
||||
self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
|
||||
else:
|
||||
self.decoder.weight = bert_model_embedding_weights
|
||||
|
||||
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
@@ -532,9 +526,9 @@ class BertLMPredictionHead(nn.Module):
|
||||
|
||||
|
||||
class BertOnlyMLMHead(nn.Module):
|
||||
def __init__(self, config, bert_model_embedding_weights):
|
||||
def __init__(self, config):
|
||||
super(BertOnlyMLMHead, self).__init__()
|
||||
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
|
||||
self.predictions = BertLMPredictionHead(config)
|
||||
|
||||
def forward(self, sequence_output):
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
@@ -552,9 +546,9 @@ class BertOnlyNSPHead(nn.Module):
|
||||
|
||||
|
||||
class BertPreTrainingHeads(nn.Module):
|
||||
def __init__(self, config, bert_model_embedding_weights):
|
||||
def __init__(self, config):
|
||||
super(BertPreTrainingHeads, self).__init__()
|
||||
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
|
||||
self.predictions = BertLMPredictionHead(config)
|
||||
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
||||
|
||||
def forward(self, sequence_output, pooled_output):
|
||||
@@ -619,6 +613,11 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
self.apply(self.init_weights)
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
@@ -750,9 +749,20 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
super(BertForPreTraining, self).__init__(config)
|
||||
|
||||
self.bert = BertModel(config)
|
||||
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
|
||||
self.cls = BertPreTrainingHeads(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
input_embeddings = self.bert.embeddings.word_embeddings.weight
|
||||
if self.config.torchscript:
|
||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
|
||||
next_sentence_label=None, head_mask=None):
|
||||
@@ -845,9 +855,20 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
super(BertForMaskedLM, self).__init__(config)
|
||||
|
||||
self.bert = BertModel(config)
|
||||
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
|
||||
self.cls = BertOnlyMLMHead(config)
|
||||
|
||||
self.apply(self.init_weights)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
""" Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
input_embeddings = self.bert.embeddings.word_embeddings.weight
|
||||
if self.config.torchscript:
|
||||
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
|
||||
else:
|
||||
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user