@@ -91,7 +91,7 @@ class DataCollatorForLanguageModeling(DataCollator):
|
|||||||
batch = self._tensorize_batch(examples)
|
batch = self._tensorize_batch(examples)
|
||||||
if self.mlm:
|
if self.mlm:
|
||||||
inputs, labels = self.mask_tokens(batch)
|
inputs, labels = self.mask_tokens(batch)
|
||||||
return {"input_ids": inputs, "masked_lm_labels": labels}
|
return {"input_ids": inputs, "labels": labels}
|
||||||
else:
|
else:
|
||||||
return {"input_ids": batch, "labels": batch}
|
return {"input_ids": batch, "labels": batch}
|
||||||
|
|
||||||
|
|||||||
@@ -74,14 +74,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
batch = data_collator.collate_batch(examples)
|
batch = data_collator.collate_batch(examples)
|
||||||
self.assertIsInstance(batch, dict)
|
self.assertIsInstance(batch, dict)
|
||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
|
||||||
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((31, 107)))
|
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
|
||||||
|
|
||||||
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
||||||
examples = [dataset[i] for i in range(len(dataset))]
|
examples = [dataset[i] for i in range(len(dataset))]
|
||||||
batch = data_collator.collate_batch(examples)
|
batch = data_collator.collate_batch(examples)
|
||||||
self.assertIsInstance(batch, dict)
|
self.assertIsInstance(batch, dict)
|
||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||||
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((2, 512)))
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user