Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task (#6644)

* add datacollator and dataset for next sentence prediction task

* bug fix (numbers of special tokens & truncate sequences)

* bug fix (+ dict inputs support for data collator)

* add padding for nsp data collator; renamed cached files to avoid conflict.

* add test for nsp data collator

* Style

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Huang Lianzhe
2020-08-31 20:25:00 +08:00
committed by GitHub
parent 895d394669
commit 2de7ee0385
5 changed files with 307 additions and 1 deletions

View File

@@ -9,11 +9,13 @@ if is_torch_available():
from transformers import (
DataCollatorForLanguageModeling,
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
TextDataset,
TextDatasetForNextSentencePrediction,
default_data_collator,
)
@@ -150,3 +152,19 @@ class DataCollatorIntegrationTest(unittest.TestCase):
with self.assertRaises(ValueError):
# Expect error due to odd sequence length
data_collator(example)
def test_nsp(self):
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
data_collator = DataCollatorForNextSentencePrediction(tokenizer)
dataset = TextDatasetForNextSentencePrediction(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
# Since there are randomly generated false samples, the total number of samples is not fixed.
total_samples = batch["input_ids"].shape[0]
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))