add lm_labels for the LM cross-entropy
This commit is contained in:
@@ -819,7 +819,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
self.bert.embeddings.word_embeddings)
|
self.bert.embeddings.word_embeddings)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||||
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
masked_lm_labels=None, lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||||
|
|
||||||
outputs = self.bert(input_ids,
|
outputs = self.bert(input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
@@ -840,7 +840,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
# of predictions for masked words.
|
# of predictions for masked words.
|
||||||
# 2. If encoder hidden states are provided we are in a causal situation where we
|
# 2. If encoder hidden states are provided we are in a causal situation where we
|
||||||
# try to predict the next word for each input in the encoder.
|
# try to predict the next word for each input in the encoder.
|
||||||
if masked_lm_labels is not None and encoder_hidden_states is not None:
|
if masked_lm_labels is not None and lm_labels is not None:
|
||||||
raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
|
raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
|
||||||
|
|
||||||
if masked_lm_labels is not None:
|
if masked_lm_labels is not None:
|
||||||
@@ -848,12 +848,12 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||||
outputs = (masked_lm_loss,) + outputs
|
outputs = (masked_lm_loss,) + outputs
|
||||||
|
|
||||||
if encoder_hidden_states is not None:
|
if lm_labels is not None:
|
||||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||||
prediction_scores = prediction_scores[:, :-1, :]
|
prediction_scores = prediction_scores[:, :-1, :]
|
||||||
input_ids = input_ids[:, 1:, :]
|
lm_labels = lm_labels[:, 1:, :]
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), input_ids.view(-1))
|
seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
|
||||||
outputs = (seq2seq_loss,) + outputs
|
outputs = (seq2seq_loss,) + outputs
|
||||||
|
|
||||||
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
|
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
|
||||||
|
|||||||
Reference in New Issue
Block a user