Dataset and DataCollator for BERT Next Sentence Prediction (NSP) task (#6644)
* add datacollator and dataset for next sentence prediction task * bug fix (numbers of special tokens & truncate sequences) * bug fix (+ dict inputs support for data collator) * add padding for nsp data collator; renamed cached files to avoid conflict. * add test for nsp data collator * Style Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -200,6 +200,7 @@ if is_torch_available():
|
|||||||
from .data.data_collator import (
|
from .data.data_collator import (
|
||||||
DataCollator,
|
DataCollator,
|
||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
|
DataCollatorForNextSentencePrediction,
|
||||||
DataCollatorForPermutationLanguageModeling,
|
DataCollatorForPermutationLanguageModeling,
|
||||||
DataCollatorWithPadding,
|
DataCollatorWithPadding,
|
||||||
default_data_collator,
|
default_data_collator,
|
||||||
@@ -211,6 +212,7 @@ if is_torch_available():
|
|||||||
SquadDataset,
|
SquadDataset,
|
||||||
SquadDataTrainingArguments,
|
SquadDataTrainingArguments,
|
||||||
TextDataset,
|
TextDataset,
|
||||||
|
TextDatasetForNextSentencePrediction,
|
||||||
)
|
)
|
||||||
from .generation_utils import top_k_top_p_filtering
|
from .generation_utils import top_k_top_p_filtering
|
||||||
from .modeling_albert import (
|
from .modeling_albert import (
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -327,3 +328,200 @@ class DataCollatorForPermutationLanguageModeling:
|
|||||||
) & masked_indices[i]
|
) & masked_indices[i]
|
||||||
|
|
||||||
return inputs, perm_mask, target_mapping, labels
|
return inputs, perm_mask, target_mapping, labels
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataCollatorForNextSentencePrediction:
|
||||||
|
"""
|
||||||
|
Data collator used for language modeling.
|
||||||
|
- collates batches of tensors, honoring their tokenizer's pad_token
|
||||||
|
- preprocesses batches for masked language modeling
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer: PreTrainedTokenizer
|
||||||
|
mlm: bool = True
|
||||||
|
block_size: int = 512
|
||||||
|
short_seq_probability: float = 0.1
|
||||||
|
nsp_probability: float = 0.5
|
||||||
|
mlm_probability: float = 0.15
|
||||||
|
|
||||||
|
def __call__(self, examples: List[Union[List[List[int]], Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||||
|
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||||
|
examples = [e["input_ids"] for e in examples]
|
||||||
|
|
||||||
|
input_ids = []
|
||||||
|
segment_ids = []
|
||||||
|
attention_masks = []
|
||||||
|
nsp_labels = []
|
||||||
|
|
||||||
|
for i, doc in enumerate(examples):
|
||||||
|
input_id, segment_id, attention_mask, label = self.create_examples_from_document(doc, i, examples)
|
||||||
|
input_ids.extend(input_id)
|
||||||
|
segment_ids.extend(segment_id)
|
||||||
|
attention_masks.extend(attention_mask)
|
||||||
|
nsp_labels.extend(label)
|
||||||
|
if self.mlm:
|
||||||
|
input_ids, mlm_labels = self.mask_tokens(self._tensorize_batch(input_ids))
|
||||||
|
else:
|
||||||
|
input_ids = self._tensorize_batch(input_ids)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": self._tensorize_batch(attention_masks),
|
||||||
|
"token_type_ids": self._tensorize_batch(segment_ids),
|
||||||
|
"masked_lm_labels": mlm_labels if self.mlm else None,
|
||||||
|
"next_sentence_label": torch.tensor(nsp_labels),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
|
||||||
|
length_of_first = examples[0].size(0)
|
||||||
|
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
||||||
|
if are_tensors_same_length:
|
||||||
|
return torch.stack(examples, dim=0)
|
||||||
|
else:
|
||||||
|
if self.tokenizer._pad_token is None:
|
||||||
|
raise ValueError(
|
||||||
|
"You are attempting to pad samples but the tokenizer you are using"
|
||||||
|
f" ({self.tokenizer.__class__.__name__}) does not have one."
|
||||||
|
)
|
||||||
|
return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
def create_examples_from_document(
|
||||||
|
self, document: List[List[int]], doc_index: int, examples: List[List[List[int]]]
|
||||||
|
):
|
||||||
|
"""Creates examples for a single document."""
|
||||||
|
|
||||||
|
max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
|
||||||
|
|
||||||
|
# We *usually* want to fill up the entire sequence since we are padding
|
||||||
|
# to `block_size` anyways, so short sequences are generally wasted
|
||||||
|
# computation. However, we *sometimes*
|
||||||
|
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
||||||
|
# sequences to minimize the mismatch between pre-training and fine-tuning.
|
||||||
|
# The `target_seq_length` is just a rough target however, whereas
|
||||||
|
# `block_size` is a hard limit.
|
||||||
|
target_seq_length = max_num_tokens
|
||||||
|
if random.random() < self.short_seq_probability:
|
||||||
|
target_seq_length = random.randint(2, max_num_tokens)
|
||||||
|
|
||||||
|
current_chunk = [] # a buffer stored current working segments
|
||||||
|
current_length = 0
|
||||||
|
i = 0
|
||||||
|
input_ids = []
|
||||||
|
segment_ids = []
|
||||||
|
attention_masks = []
|
||||||
|
labels = []
|
||||||
|
while i < len(document):
|
||||||
|
segment = document[i]
|
||||||
|
current_chunk.append(segment)
|
||||||
|
current_length += len(segment)
|
||||||
|
if i == len(document) - 1 or current_length >= target_seq_length:
|
||||||
|
if current_chunk:
|
||||||
|
# `a_end` is how many segments from `current_chunk` go into the `A`
|
||||||
|
# (first) sentence.
|
||||||
|
a_end = 1
|
||||||
|
if len(current_chunk) >= 2:
|
||||||
|
a_end = random.randint(1, len(current_chunk) - 1)
|
||||||
|
|
||||||
|
tokens_a = []
|
||||||
|
for j in range(a_end):
|
||||||
|
tokens_a.extend(current_chunk[j])
|
||||||
|
|
||||||
|
tokens_b = []
|
||||||
|
|
||||||
|
if len(current_chunk) == 1 or random.random() < self.nsp_probability:
|
||||||
|
is_random_next = True
|
||||||
|
target_b_length = target_seq_length - len(tokens_a)
|
||||||
|
|
||||||
|
# This should rarely go for more than one iteration for large
|
||||||
|
# corpora. However, just to be careful, we try to make sure that
|
||||||
|
# the random document is not the same as the document
|
||||||
|
# we're processing.
|
||||||
|
for _ in range(10):
|
||||||
|
random_document_index = random.randint(0, len(examples) - 1)
|
||||||
|
if random_document_index != doc_index:
|
||||||
|
break
|
||||||
|
|
||||||
|
random_document = examples[random_document_index]
|
||||||
|
random_start = random.randint(0, len(random_document) - 1)
|
||||||
|
for j in range(random_start, len(random_document)):
|
||||||
|
tokens_b.extend(random_document[j])
|
||||||
|
if len(tokens_b) >= target_b_length:
|
||||||
|
break
|
||||||
|
# We didn't actually use these segments so we "put them back" so
|
||||||
|
# they don't go to waste.
|
||||||
|
num_unused_segments = len(current_chunk) - a_end
|
||||||
|
i -= num_unused_segments
|
||||||
|
# Actual next
|
||||||
|
else:
|
||||||
|
is_random_next = False
|
||||||
|
for j in range(a_end, len(current_chunk)):
|
||||||
|
tokens_b.extend(current_chunk[j])
|
||||||
|
|
||||||
|
assert len(tokens_a) >= 1
|
||||||
|
assert len(tokens_b) >= 1
|
||||||
|
|
||||||
|
tokens_a, tokens_b, _ = self.tokenizer.truncate_sequences(
|
||||||
|
tokens_a,
|
||||||
|
tokens_b,
|
||||||
|
num_tokens_to_remove=len(tokens_a) + len(tokens_b) - max_num_tokens,
|
||||||
|
truncation_strategy="longest_first",
|
||||||
|
)
|
||||||
|
|
||||||
|
input_id = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
|
||||||
|
attention_mask = [1] * len(input_id)
|
||||||
|
segment_id = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
|
||||||
|
assert len(input_id) <= self.block_size
|
||||||
|
|
||||||
|
# pad
|
||||||
|
while len(input_id) < self.block_size:
|
||||||
|
input_id.append(0)
|
||||||
|
attention_mask.append(0)
|
||||||
|
segment_id.append(0)
|
||||||
|
|
||||||
|
input_ids.append(torch.tensor(input_id))
|
||||||
|
segment_ids.append(torch.tensor(segment_id))
|
||||||
|
attention_masks.append(torch.tensor(attention_mask))
|
||||||
|
labels.append(torch.tensor(1 if is_random_next else 0))
|
||||||
|
|
||||||
|
current_chunk = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return input_ids, segment_ids, attention_masks, labels
|
||||||
|
|
||||||
|
def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.tokenizer.mask_token is None:
|
||||||
|
raise ValueError(
|
||||||
|
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
|
||||||
|
)
|
||||||
|
|
||||||
|
labels = inputs.clone()
|
||||||
|
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
||||||
|
probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
||||||
|
special_tokens_mask = [
|
||||||
|
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
||||||
|
]
|
||||||
|
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
||||||
|
if self.tokenizer._pad_token is not None:
|
||||||
|
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
||||||
|
probability_matrix.masked_fill_(padding_mask, value=0.0)
|
||||||
|
masked_indices = torch.bernoulli(probability_matrix).bool()
|
||||||
|
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||||
|
|
||||||
|
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||||
|
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
||||||
|
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
||||||
|
|
||||||
|
# 10% of the time, we replace masked input tokens with random word
|
||||||
|
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
||||||
|
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
||||||
|
inputs[indices_random] = random_words[indices_random]
|
||||||
|
|
||||||
|
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||||
|
return inputs, labels
|
||||||
|
|||||||
@@ -3,5 +3,5 @@
|
|||||||
# module, but to preserve other warnings. So, don't check this module at all.
|
# module, but to preserve other warnings. So, don't check this module at all.
|
||||||
|
|
||||||
from .glue import GlueDataset, GlueDataTrainingArguments
|
from .glue import GlueDataset, GlueDataTrainingArguments
|
||||||
from .language_modeling import LineByLineTextDataset, TextDataset
|
from .language_modeling import LineByLineTextDataset, TextDataset, TextDatasetForNextSentencePrediction
|
||||||
from .squad import SquadDataset, SquadDataTrainingArguments
|
from .squad import SquadDataset, SquadDataTrainingArguments
|
||||||
|
|||||||
@@ -109,3 +109,91 @@ class LineByLineTextDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, i) -> torch.Tensor:
|
def __getitem__(self, i) -> torch.Tensor:
|
||||||
return torch.tensor(self.examples[i], dtype=torch.long)
|
return torch.tensor(self.examples[i], dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
class TextDatasetForNextSentencePrediction(Dataset):
|
||||||
|
"""
|
||||||
|
This will be superseded by a framework-agnostic approach
|
||||||
|
soon.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
file_path: str,
|
||||||
|
block_size: int,
|
||||||
|
overwrite_cache=False,
|
||||||
|
):
|
||||||
|
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
|
||||||
|
|
||||||
|
block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True)
|
||||||
|
|
||||||
|
directory, filename = os.path.split(file_path)
|
||||||
|
cached_features_file = os.path.join(
|
||||||
|
directory,
|
||||||
|
"cached_nsp_{}_{}_{}".format(
|
||||||
|
tokenizer.__class__.__name__,
|
||||||
|
str(block_size),
|
||||||
|
filename,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.examples = []
|
||||||
|
|
||||||
|
# Make sure only the first process in distributed training processes the dataset,
|
||||||
|
# and the others will use the cache.
|
||||||
|
lock_path = cached_features_file + ".lock"
|
||||||
|
|
||||||
|
# Input file format:
|
||||||
|
# (1) One sentence per line. These should ideally be actual sentences, not
|
||||||
|
# entire paragraphs or arbitrary spans of text. (Because we use the
|
||||||
|
# sentence boundaries for the "next sentence prediction" task).
|
||||||
|
# (2) Blank lines between documents. Document boundaries are needed so
|
||||||
|
# that the "next sentence prediction" task doesn't span between documents.
|
||||||
|
#
|
||||||
|
# Example:
|
||||||
|
# I am very happy.
|
||||||
|
# Here is the second sentence.
|
||||||
|
#
|
||||||
|
# A new document.
|
||||||
|
|
||||||
|
with FileLock(lock_path):
|
||||||
|
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||||
|
start = time.time()
|
||||||
|
with open(cached_features_file, "rb") as handle:
|
||||||
|
self.examples = pickle.load(handle)
|
||||||
|
logger.info(
|
||||||
|
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"Creating features from dataset file at {directory}")
|
||||||
|
|
||||||
|
self.examples = [[]]
|
||||||
|
with open(file_path, encoding="utf-8") as f:
|
||||||
|
while True:
|
||||||
|
line = f.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
line = line.strip()
|
||||||
|
|
||||||
|
# Empty lines are used as document delimiters
|
||||||
|
if not line and len(self.examples[-1]) != 0:
|
||||||
|
self.examples.append([])
|
||||||
|
tokens = tokenizer.tokenize(line)
|
||||||
|
tokens = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
if tokens:
|
||||||
|
self.examples[-1].append(tokens)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
with open(cached_features_file, "wb") as handle:
|
||||||
|
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
logger.info(
|
||||||
|
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.examples)
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return self.examples[i]
|
||||||
|
|||||||
@@ -9,11 +9,13 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
|
DataCollatorForNextSentencePrediction,
|
||||||
DataCollatorForPermutationLanguageModeling,
|
DataCollatorForPermutationLanguageModeling,
|
||||||
GlueDataset,
|
GlueDataset,
|
||||||
GlueDataTrainingArguments,
|
GlueDataTrainingArguments,
|
||||||
LineByLineTextDataset,
|
LineByLineTextDataset,
|
||||||
TextDataset,
|
TextDataset,
|
||||||
|
TextDatasetForNextSentencePrediction,
|
||||||
default_data_collator,
|
default_data_collator,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -150,3 +152,19 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
# Expect error due to odd sequence length
|
# Expect error due to odd sequence length
|
||||||
data_collator(example)
|
data_collator(example)
|
||||||
|
|
||||||
|
def test_nsp(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
||||||
|
data_collator = DataCollatorForNextSentencePrediction(tokenizer)
|
||||||
|
|
||||||
|
dataset = TextDatasetForNextSentencePrediction(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
||||||
|
examples = [dataset[i] for i in range(len(dataset))]
|
||||||
|
batch = data_collator(examples)
|
||||||
|
self.assertIsInstance(batch, dict)
|
||||||
|
|
||||||
|
# Since there are randomly generated false samples, the total number of samples is not fixed.
|
||||||
|
total_samples = batch["input_ids"].shape[0]
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
|
||||||
|
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
|
||||||
|
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512)))
|
||||||
|
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
|
||||||
|
|||||||
Reference in New Issue
Block a user