From 4dd5cf22073f86f559479945fcec568190267fb5 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 5 Jun 2020 15:20:29 -0400 Subject: [PATCH] Fix argument label (#4792) * Fix argument label * Fix test --- src/transformers/data/data_collator.py | 2 +- tests/test_trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index b8f3f571b6..7cd095651c 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -91,7 +91,7 @@ class DataCollatorForLanguageModeling(DataCollator): batch = self._tensorize_batch(examples) if self.mlm: inputs, labels = self.mask_tokens(batch) - return {"input_ids": inputs, "masked_lm_labels": labels} + return {"input_ids": inputs, "labels": labels} else: return {"input_ids": batch, "labels": batch} diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 023f7ba6b0..1717030376 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -74,14 +74,14 @@ class DataCollatorIntegrationTest(unittest.TestCase): batch = data_collator.collate_batch(examples) self.assertIsInstance(batch, dict) self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107))) - self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((31, 107))) + self.assertEqual(batch["labels"].shape, torch.Size((31, 107))) dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True) examples = [dataset[i] for i in range(len(dataset))] batch = data_collator.collate_batch(examples) self.assertIsInstance(batch, dict) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512))) - self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((2, 512))) + self.assertEqual(batch["labels"].shape, torch.Size((2, 512))) @require_torch