Albert pretrain datasets/ datacollator (#6168)
* add dataset for ALBERT pretraining * add data collator for ALBERT pretraining * naming, comprehension, and file-reading changes * data cleaning is no longer needed after this modification * delete prints * fix a bug * change file structure * add tests for the ALBERT data collator * remove random seed * add back len and get item functions * add sample file for testing and test code * format change for black * more format changes * Style * resolve var assignment issue * add back wrongly deleted DataCollatorWithPadding in init file * Style Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -11,9 +11,11 @@ if is_torch_available():
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForNextSentencePrediction,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
DataCollatorForSOP,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
LineByLineTextDataset,
|
||||
LineByLineWithSOPTextDataset,
|
||||
TextDataset,
|
||||
TextDatasetForNextSentencePrediction,
|
||||
default_data_collator,
|
||||
@@ -21,6 +23,7 @@ if is_torch_available():
|
||||
|
||||
|
||||
# Path to a sample text file under the tests fixtures directory (relative path —
# presumably resolved from the repository root; confirm against the test runner).
PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
|
||||
# Directory of sample text fixtures; passed as `file_dir` to
# LineByLineWithSOPTextDataset in `test_sop`. Presumably holds wiki-style text
# files, judging by the directory name — confirm against the fixtures.
PATH_SAMPLE_TEXT_DIR = "./tests/fixtures/tests_samples/wiki_text"
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -168,3 +171,19 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
|
||||
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512)))
|
||||
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
|
||||
|
||||
def test_sop(self):
    """Collate a sentence-order-prediction dataset and verify the batch layout.

    Builds a `LineByLineWithSOPTextDataset` from the sample text directory with
    the "albert-base-v2" tokenizer, runs every example through
    `DataCollatorForSOP`, and checks that the result is a dict whose tensors
    have the expected shapes for a block size of 512.
    """
    albert_tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
    sop_collator = DataCollatorForSOP(albert_tokenizer)

    sop_dataset = LineByLineWithSOPTextDataset(albert_tokenizer, file_dir=PATH_SAMPLE_TEXT_DIR, block_size=512)

    # Materialize every example by index before handing them to the collator.
    collected = []
    for index in range(len(sop_dataset)):
        collected.append(sop_dataset[index])

    batch = sop_collator(collected)
    self.assertIsInstance(batch, dict)

    # The dataset injects randomly generated negative samples, so the batch
    # size is read back from the batch itself instead of being hard-coded.
    total_samples = batch["input_ids"].shape[0]
    per_token_shape = torch.Size((total_samples, 512))

    # Token-level tensors: one row per sample, one column per position.
    for key in ("input_ids", "token_type_ids", "labels"):
        self.assertEqual(batch[key].shape, per_token_shape)

    # The sentence-order label is a single value per sample.
    self.assertEqual(batch["sentence_order_label"].shape, torch.Size((total_samples,)))
|
||||
|
||||
Reference in New Issue
Block a user