From 4e0f24348fcc9902664951677ffc7c8cc171443d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 17 Oct 2019 09:41:53 +0200 Subject: [PATCH] document the MLM modification + raise exception on MLM training with encoder-decoder --- transformers/modeling_bert.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index e717031dcb..2553bc0efb 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -830,21 +830,30 @@ class BertForMaskedLM(BertPreTrainedModel): prediction_scores = self.cls(sequence_output) outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here + + # Although this may seem awkward, BertForMaskedLM supports two scenarios: + # 1. If a tensor that contains the indices of masked labels is provided, + # the cross-entropy is the MLM cross-entropy that measures the likelihood + # of predictions for masked words. + # 2. If encoder hidden states are provided we are in a causal situation where we + # try to predict the next word for each token of input_ids (the decoder input).
+ if masked_lm_labels is not None and encoder_hidden_states is not None: + raise AttributeError("Masked LM training with an encoder-decoder is not supported.") + if masked_lm_labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) + loss_fct = CrossEntropyLoss(ignore_index=-1) # labels equal to -1 are ignored by the loss masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) outputs = (masked_lm_loss,) + outputs if encoder_hidden_states is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - - # shift predictions scores and input ids by one before computing loss + # we are doing next-token prediction; shift prediction scores and input ids by one prediction_scores = prediction_scores[:, :-1, :] input_ids = input_ids[:, 1:, :] + loss_fct = CrossEntropyLoss(ignore_index=-1) seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), input_ids.view(-1)) outputs = (seq2seq_loss,) + outputs - return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) + return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions) @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,