From 63276b76d4fb54d096b491e89632859aed6b4364 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 21 Sep 2020 10:31:26 -0400 Subject: [PATCH] Fix #7284 (#7289) --- src/transformers/data/data_collator.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index a2d001034b..9768c7d1f3 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -434,13 +434,15 @@ class DataCollatorForNextSentencePrediction: else: input_ids = self._tensorize_batch(input_ids) - return { + result = { "input_ids": input_ids, "attention_mask": self._tensorize_batch(attention_masks), "token_type_ids": self._tensorize_batch(segment_ids), - "masked_lm_labels": mlm_labels if self.mlm else None, "next_sentence_label": torch.tensor(nsp_labels), } + if self.mlm: + result["masked_lm_labels"] = mlm_labels + return result def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: length_of_first = examples[0].size(0)