Fix label name in DataCollatorForNextSentencePrediction test (#8048)
This commit is contained in:
@@ -175,7 +175,7 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
total_samples = batch["input_ids"].shape[0]
|
total_samples = batch["input_ids"].shape[0]
|
||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
|
||||||
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
|
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
|
||||||
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512)))
|
self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
|
||||||
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
|
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user