Albert pretrain datasets/ datacollator (#6168)

* add dataset for albert pretrain

* datacollator for albert pretrain

* naming, comprehension, file reading change

* data cleaning is no needed after this modification

* delete prints

* fix a bug

* file structure change

* add tests for albert datacollator

* remove random seed

* add back len and get item function

* sample file for testing and test code added

* format change for black

* more format change

* Style

* var assignment issue resolve

* add back wrongly deleted DataCollatorWithPadding in init file

* Style

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Yu Liu
2020-09-10 04:56:29 -07:00
committed by GitHub
parent 49e9be0639
commit 762cba3bda
6 changed files with 490 additions and 2 deletions

View File

@@ -11,9 +11,11 @@ if is_torch_available():
DataCollatorForLanguageModeling,
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
LineByLineWithSOPTextDataset,
TextDataset,
TextDatasetForNextSentencePrediction,
default_data_collator,
@@ -21,6 +23,7 @@ if is_torch_available():
PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
PATH_SAMPLE_TEXT_DIR = "./tests/fixtures/tests_samples/wiki_text"
@require_torch
@@ -168,3 +171,19 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
def test_sop(self):
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
data_collator = DataCollatorForSOP(tokenizer)
dataset = LineByLineWithSOPTextDataset(tokenizer, file_dir=PATH_SAMPLE_TEXT_DIR, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
# Since there are randomly generated false samples, the total number of samples is not fixed.
total_samples = batch["input_ids"].shape[0]
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((total_samples,)))