Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task (#6644)
* add datacollator and dataset for next sentence prediction task * bug fix (numbers of special tokens & truncate sequences) * bug fix (+ dict inputs support for data collator) * add padding for nsp data collator; renamed cached files to avoid conflict. * add test for nsp data collator * Style Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -9,11 +9,13 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForNextSentencePrediction,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
LineByLineTextDataset,
|
||||
TextDataset,
|
||||
TextDatasetForNextSentencePrediction,
|
||||
default_data_collator,
|
||||
)
|
||||
|
||||
@@ -150,3 +152,19 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
# Expect error due to odd sequence length
|
||||
data_collator(example)
|
||||
|
||||
def test_nsp(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
||||
data_collator = DataCollatorForNextSentencePrediction(tokenizer)
|
||||
|
||||
dataset = TextDatasetForNextSentencePrediction(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
batch = data_collator(examples)
|
||||
self.assertIsInstance(batch, dict)
|
||||
|
||||
# Since there are randomly generated false samples, the total number of samples is not fixed.
|
||||
total_samples = batch["input_ids"].shape[0]
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
|
||||
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
|
||||
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512)))
|
||||
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
|
||||
|
||||
Reference in New Issue
Block a user