From 2de7ee0385bee4134ca894a208fa3a2aaf7d5371 Mon Sep 17 00:00:00 2001 From: Huang Lianzhe Date: Mon, 31 Aug 2020 20:25:00 +0800 Subject: [PATCH] Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task (#6644) * add datacollator and dataset for next sentence prediction task * bug fix (numbers of special tokens & truncate sequences) * bug fix (+ dict inputs support for data collator) * add padding for nsp data collator; renamed cached files to avoid conflict. * add test for nsp data collator * Style Co-authored-by: Lysandre Debut Co-authored-by: Lysandre --- src/transformers/__init__.py | 2 + src/transformers/data/data_collator.py | 198 ++++++++++++++++++ src/transformers/data/datasets/__init__.py | 2 +- .../data/datasets/language_modeling.py | 88 ++++++++ tests/test_data_collator.py | 18 ++ 5 files changed, 307 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9558fb457e..502da78555 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -200,6 +200,7 @@ if is_torch_available(): from .data.data_collator import ( DataCollator, DataCollatorForLanguageModeling, + DataCollatorForNextSentencePrediction, DataCollatorForPermutationLanguageModeling, DataCollatorWithPadding, default_data_collator, @@ -211,6 +212,7 @@ if is_torch_available(): SquadDataset, SquadDataTrainingArguments, TextDataset, + TextDatasetForNextSentencePrediction, ) from .generation_utils import top_k_top_p_filtering from .modeling_albert import ( diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index b14d06d4fb..ceb36ed74f 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1,3 +1,4 @@ +import random from dataclasses import dataclass from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union @@ -327,3 +328,200 @@ class DataCollatorForPermutationLanguageModeling: ) & masked_indices[i] return inputs, perm_mask, target_mapping, labels + + +@dataclass +class DataCollatorForNextSentencePrediction: + """ + Data collator used for language modeling. + - collates batches of tensors, honoring their tokenizer's pad_token + - preprocesses batches for masked language modeling + """ + + tokenizer: PreTrainedTokenizer + mlm: bool = True + block_size: int = 512 + short_seq_probability: float = 0.1 + nsp_probability: float = 0.5 + mlm_probability: float = 0.15 + + def __call__(self, examples: List[Union[List[List[int]], Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + if isinstance(examples[0], (dict, BatchEncoding)): + examples = [e["input_ids"] for e in examples] + + input_ids = [] + segment_ids = [] + attention_masks = [] + nsp_labels = [] + + for i, doc in enumerate(examples): + input_id, segment_id, attention_mask, label = self.create_examples_from_document(doc, i, examples) + input_ids.extend(input_id) + segment_ids.extend(segment_id) + attention_masks.extend(attention_mask) + nsp_labels.extend(label) + if self.mlm: + input_ids, mlm_labels = self.mask_tokens(self._tensorize_batch(input_ids)) + else: + input_ids = self._tensorize_batch(input_ids) + + return { + "input_ids": input_ids, + "attention_mask": self._tensorize_batch(attention_masks), + "token_type_ids": self._tensorize_batch(segment_ids), + "masked_lm_labels": mlm_labels if self.mlm else None, + "next_sentence_label": torch.tensor(nsp_labels), + } + + def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: + length_of_first = examples[0].size(0) + are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) + if are_tensors_same_length: + return torch.stack(examples, dim=0) + else: + if self.tokenizer._pad_token is None: + raise ValueError( + "You are attempting to pad samples but the tokenizer you are using" + f" ({self.tokenizer.__class__.__name__}) does not have one." + ) + return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id) + + def create_examples_from_document( + self, document: List[List[int]], doc_index: int, examples: List[List[List[int]]] + ): + """Creates examples for a single document.""" + + max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True) + + # We *usually* want to fill up the entire sequence since we are padding + # to `block_size` anyways, so short sequences are generally wasted + # computation. However, we *sometimes* + # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter + # sequences to minimize the mismatch between pre-training and fine-tuning. + # The `target_seq_length` is just a rough target however, whereas + # `block_size` is a hard limit. + target_seq_length = max_num_tokens + if random.random() < self.short_seq_probability: + target_seq_length = random.randint(2, max_num_tokens) + + current_chunk = [] # a buffer stored current working segments + current_length = 0 + i = 0 + input_ids = [] + segment_ids = [] + attention_masks = [] + labels = [] + while i < len(document): + segment = document[i] + current_chunk.append(segment) + current_length += len(segment) + if i == len(document) - 1 or current_length >= target_seq_length: + if current_chunk: + # `a_end` is how many segments from `current_chunk` go into the `A` + # (first) sentence. + a_end = 1 + if len(current_chunk) >= 2: + a_end = random.randint(1, len(current_chunk) - 1) + + tokens_a = [] + for j in range(a_end): + tokens_a.extend(current_chunk[j]) + + tokens_b = [] + + if len(current_chunk) == 1 or random.random() < self.nsp_probability: + is_random_next = True + target_b_length = target_seq_length - len(tokens_a) + + # This should rarely go for more than one iteration for large + # corpora. However, just to be careful, we try to make sure that + # the random document is not the same as the document + # we're processing. + for _ in range(10): + random_document_index = random.randint(0, len(examples) - 1) + if random_document_index != doc_index: + break + + random_document = examples[random_document_index] + random_start = random.randint(0, len(random_document) - 1) + for j in range(random_start, len(random_document)): + tokens_b.extend(random_document[j]) + if len(tokens_b) >= target_b_length: + break + # We didn't actually use these segments so we "put them back" so + # they don't go to waste. + num_unused_segments = len(current_chunk) - a_end + i -= num_unused_segments + # Actual next + else: + is_random_next = False + for j in range(a_end, len(current_chunk)): + tokens_b.extend(current_chunk[j]) + + assert len(tokens_a) >= 1 + assert len(tokens_b) >= 1 + + tokens_a, tokens_b, _ = self.tokenizer.truncate_sequences( + tokens_a, + tokens_b, + num_tokens_to_remove=len(tokens_a) + len(tokens_b) - max_num_tokens, + truncation_strategy="longest_first", + ) + + input_id = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b) + attention_mask = [1] * len(input_id) + segment_id = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b) + assert len(input_id) <= self.block_size + + # pad + while len(input_id) < self.block_size: + input_id.append(0) + attention_mask.append(0) + segment_id.append(0) + + input_ids.append(torch.tensor(input_id)) + segment_ids.append(torch.tensor(segment_id)) + attention_masks.append(torch.tensor(attention_mask)) + labels.append(torch.tensor(1 if is_random_next else 0)) + + current_chunk = [] + current_length = 0 + + i += 1 + + return input_ids, segment_ids, attention_masks, labels + + def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. + """ + + if self.tokenizer.mask_token is None: + raise ValueError( + "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." + ) + + labels = inputs.clone() + # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) + probability_matrix = torch.full(labels.shape, self.mlm_probability) + special_tokens_mask = [ + self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() + ] + probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + if self.tokenizer._pad_token is not None: + padding_mask = labels.eq(self.tokenizer.pad_token_id) + probability_matrix.masked_fill_(padding_mask, value=0.0) + masked_indices = torch.bernoulli(probability_matrix).bool() + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices + inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels diff --git a/src/transformers/data/datasets/__init__.py b/src/transformers/data/datasets/__init__.py index ca2ab15e43..f4e2aac5e9 100644 --- a/src/transformers/data/datasets/__init__.py +++ b/src/transformers/data/datasets/__init__.py @@ -3,5 +3,5 @@ # module, but to preserve other warnings. So, don't check this module at all. from .glue import GlueDataset, GlueDataTrainingArguments -from .language_modeling import LineByLineTextDataset, TextDataset +from .language_modeling import LineByLineTextDataset, TextDataset, TextDatasetForNextSentencePrediction from .squad import SquadDataset, SquadDataTrainingArguments diff --git a/src/transformers/data/datasets/language_modeling.py b/src/transformers/data/datasets/language_modeling.py index 71a5950031..1a377a60b1 100644 --- a/src/transformers/data/datasets/language_modeling.py +++ b/src/transformers/data/datasets/language_modeling.py @@ -109,3 +109,91 @@ class LineByLineTextDataset(Dataset): def __getitem__(self, i) -> torch.Tensor: return torch.tensor(self.examples[i], dtype=torch.long) + + +class TextDatasetForNextSentencePrediction(Dataset): + """ + This will be superseded by a framework-agnostic approach + soon. + """ + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + file_path: str, + block_size: int, + overwrite_cache=False, + ): + assert os.path.isfile(file_path), f"Input file path {file_path} not found" + + block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True) + + directory, filename = os.path.split(file_path) + cached_features_file = os.path.join( + directory, + "cached_nsp_{}_{}_{}".format( + tokenizer.__class__.__name__, + str(block_size), + filename, + ), + ) + + self.tokenizer = tokenizer + self.examples = [] + + # Make sure only the first process in distributed training processes the dataset, + # and the others will use the cache. + lock_path = cached_features_file + ".lock" + + # Input file format: + # (1) One sentence per line. These should ideally be actual sentences, not + # entire paragraphs or arbitrary spans of text. (Because we use the + # sentence boundaries for the "next sentence prediction" task). + # (2) Blank lines between documents. Document boundaries are needed so + # that the "next sentence prediction" task doesn't span between documents. + # + # Example: + # I am very happy. + # Here is the second sentence. + # + # A new document. + + with FileLock(lock_path): + if os.path.exists(cached_features_file) and not overwrite_cache: + start = time.time() + with open(cached_features_file, "rb") as handle: + self.examples = pickle.load(handle) + logger.info( + f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start + ) + else: + logger.info(f"Creating features from dataset file at {directory}") + + self.examples = [[]] + with open(file_path, encoding="utf-8") as f: + while True: + line = f.readline() + if not line: + break + line = line.strip() + + # Empty lines are used as document delimiters + if not line and len(self.examples[-1]) != 0: + self.examples.append([]) + tokens = tokenizer.tokenize(line) + tokens = tokenizer.convert_tokens_to_ids(tokens) + if tokens: + self.examples[-1].append(tokens) + + start = time.time() + with open(cached_features_file, "wb") as handle: + pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) + logger.info( + "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start + ) + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + return self.examples[i] diff --git a/tests/test_data_collator.py b/tests/test_data_collator.py index 41b3b371b9..2ec65e5738 100644 --- a/tests/test_data_collator.py +++ b/tests/test_data_collator.py @@ -9,11 +9,13 @@ if is_torch_available(): from transformers import ( DataCollatorForLanguageModeling, + DataCollatorForNextSentencePrediction, DataCollatorForPermutationLanguageModeling, GlueDataset, GlueDataTrainingArguments, LineByLineTextDataset, TextDataset, + TextDatasetForNextSentencePrediction, default_data_collator, ) @@ -150,3 +152,19 @@ class DataCollatorIntegrationTest(unittest.TestCase): with self.assertRaises(ValueError): # Expect error due to odd sequence length data_collator(example) + + def test_nsp(self): + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + data_collator = DataCollatorForNextSentencePrediction(tokenizer) + + dataset = TextDatasetForNextSentencePrediction(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512) + examples = [dataset[i] for i in range(len(dataset))] + batch = data_collator(examples) + self.assertIsInstance(batch, dict) + + # Since there are randomly generated false samples, the total number of samples is not fixed. + total_samples = batch["input_ids"].shape[0] + self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512))) + self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512))) + self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512))) + self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))