From dc580dd4c720c5daefe7411f604b6908da99681e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 17 Oct 2019 16:56:36 +0200 Subject: [PATCH] add lm_labels for the LM cross-entropy --- transformers/modeling_bert.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index aa022bac8a..d10f32c1fa 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -819,7 +819,7 @@ class BertForMaskedLM(BertPreTrainedModel): self.bert.embeddings.word_embeddings) def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, - masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None): + masked_lm_labels=None, lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None): outputs = self.bert(input_ids, attention_mask=attention_mask, @@ -840,7 +840,7 @@ class BertForMaskedLM(BertPreTrainedModel): # of predictions for masked words. # 2. If encoder hidden states are provided we are in a causal situation where we # try to predict the next word for each input in the encoder. - if masked_lm_labels is not None and encoder_hidden_states is not None: + if masked_lm_labels is not None and lm_labels is not None: raise AttributeError("Masked LM training with an encoder-decoder is not supported.") if masked_lm_labels is not None: @@ -848,12 +848,12 @@ class BertForMaskedLM(BertPreTrainedModel): masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) outputs = (masked_lm_loss,) + outputs - if encoder_hidden_states is not None: + if lm_labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one prediction_scores = prediction_scores[:, :-1, :] - input_ids = input_ids[:, 1:, :] + lm_labels = lm_labels[:, 1:, :] loss_fct = CrossEntropyLoss(ignore_index=-1) - seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), input_ids.view(-1)) + seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1)) outputs = (seq2seq_loss,) + outputs return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)