Fix argument label (#4792)

* Fix argument label

* Fix test
This commit is contained in:
Sylvain Gugger
2020-06-05 15:20:29 -04:00
committed by GitHub
parent 3723f30a18
commit 4dd5cf2207
2 changed files with 3 additions and 3 deletions

View File

@@ -74,14 +74,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
batch = data_collator.collate_batch(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((31, 107)))
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
@require_torch