Fix label name in DataCollatorForNextSentencePrediction test (#8048)

This commit is contained in:
Sylvain Gugger
2020-10-26 09:23:12 -04:00
committed by GitHub
parent 8bbe8247f1
commit 077478637d

View File

@@ -175,7 +175,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
total_samples = batch["input_ids"].shape[0] total_samples = batch["input_ids"].shape[0]
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512))) self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512))) self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512))) self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,))) self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
@slow @slow