* MOD: fit chinese wwm to new datasets * MOD: move wwm to new folder * MOD: formate code * Styling * MOD add param and recover trainer Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
689 lines
34 KiB
Python
689 lines
34 KiB
Python
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import random
|
|
import warnings
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
from ..tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase
|
|
|
|
|
|
# Placeholder type for a single dataset sample; declared over `Any`, so it does not constrain the sample's shape.
InputDataClass = NewType("InputDataClass", Any)

"""
A DataCollator is a function that takes a list of samples from a Dataset and collate them into a batch, as a dictionary
of Tensors.
"""
# Callable type for a collator: takes a list of samples and returns a dict mapping names to tensors.
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
|
|
|
|
|
|
def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    """
    Collate a list of dict-like samples into a dictionary of batched tensors.

    No preprocessing is applied: every key of the first sample whose value is neither :obj:`None` nor a string becomes
    a key of the batch. Two keys receive special treatment and are both emitted under ``"labels"``:

        - ``label``: a single value (int or float) per sample
        - ``label_ids``: a list (or tensor) of values per sample

    ``label`` takes precedence over ``label_ids`` when both are present and not :obj:`None`.

    Args:
        features: The samples to collate. Samples that are neither ``dict`` nor ``BatchEncoding`` are converted with
            :func:`vars`. All samples are assumed to expose the same keys as the first one.

    Returns:
        A dictionary mapping each retained key to a tensor whose first dimension indexes the samples.
    """
    if not isinstance(features[0], (dict, BatchEncoding)):
        features = [vars(sample) for sample in features]

    # The first sample acts as the schema for the whole batch.
    reference = features[0]
    collated = {}

    # Labels: pick the dtype explicitly from the first sample (long for Python ints, float otherwise).
    if "label" in reference and reference["label"] is not None:
        probe = reference["label"]
        if isinstance(probe, torch.Tensor):
            probe = probe.item()
        label_dtype = torch.long if isinstance(probe, int) else torch.float
        collated["labels"] = torch.tensor([sample["label"] for sample in features], dtype=label_dtype)
    elif "label_ids" in reference and reference["label_ids"] is not None:
        if isinstance(reference["label_ids"], torch.Tensor):
            collated["labels"] = torch.stack([sample["label_ids"] for sample in features])
        else:
            label_dtype = torch.float
            if type(reference["label_ids"][0]) is int:
                label_dtype = torch.long
            collated["labels"] = torch.tensor([sample["label_ids"] for sample in features], dtype=label_dtype)

    # Everything else: stack tensors, tensorize plain Python values, skip strings and None.
    for name, value in reference.items():
        if name in ("label", "label_ids") or value is None or isinstance(value, str):
            continue
        column = [sample[name] for sample in features]
        if isinstance(value, torch.Tensor):
            collated[name] = torch.stack(column)
        else:
            collated[name] = torch.tensor(column)

    return collated
|
|
|
|
|
|
@dataclass
class DataCollatorWithPadding:
    """
    Data collator that dynamically pads the received inputs by delegating to ``tokenizer.pad``.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Strategy used to pad the returned sequences (according to the model's padding side and padding index):

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
              single sequence is provided).
            * :obj:`'max_length'`: Pad to the length given by :obj:`max_length`, or to the maximum acceptable input
              length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
              lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, pad the sequence to a multiple of the provided value. This is especially useful to enable the
            use of Tensor Cores on NVIDIA hardware.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Pad ``features`` into PyTorch tensors and expose any ``label`` / ``label_ids`` column as ``labels``."""
        padded = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Rename label columns to "labels"; when both are present, "label_ids" is processed last and wins.
        for alias in ("label", "label_ids"):
            if alias in padded:
                padded["labels"] = padded[alias]
                del padded[alias]
        return padded
|
|
|
|
|
|
@dataclass
class DataCollatorForTokenClassification:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
              single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
              lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
            the use of Tensor Cores on NVIDIA hardware.
        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100

    def __call__(self, features):
        """
        Pad the inputs with the tokenizer and pad the per-token labels to the same length.

        The label column is read from ``"label"`` if the first feature has it, otherwise from ``"labels"``. When no
        label column exists, the tokenizer output (as PyTorch tensors) is returned directly. Otherwise a plain dict of
        ``int64`` tensors is returned with the padded labels stored under ``"labels"``.
        """
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            # Conversion to tensors will fail if we have labels as they are not of the same length yet.
            return_tensors="pt" if labels is None else None,
        )

        if labels is None:
            return batch

        sequence_length = torch.tensor(batch["input_ids"]).shape[1]
        pad_id = self.label_pad_token_id
        # Pad labels on the same side as the tokenizer pads the inputs so they stay aligned token by token.
        if self.tokenizer.padding_side == "right":
            padded_labels = [list(label) + [pad_id] * (sequence_length - len(label)) for label in labels]
        else:
            padded_labels = [[pad_id] * (sequence_length - len(label)) + list(label) for label in labels]

        # Skip the raw label column: it is still unpadded (ragged), so tensorizing it would fail, and when it is
        # named "label" it would otherwise be emitted next to "labels".
        tensors = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items() if k != label_name}
        tensors["labels"] = torch.tensor(padded_labels, dtype=torch.int64)
        return tensors
|
|
|
|
|
|
def _collate_batch(examples, tokenizer):
|
|
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
|
# Tensorize if necessary.
|
|
if isinstance(examples[0], (list, tuple)):
|
|
examples = [torch.tensor(e, dtype=torch.long) for e in examples]
|
|
|
|
# Check if padding is necessary.
|
|
length_of_first = examples[0].size(0)
|
|
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
|
if are_tensors_same_length:
|
|
return torch.stack(examples, dim=0)
|
|
|
|
# If yes, check if we have a `pad_token`.
|
|
if tokenizer._pad_token is None:
|
|
raise ValueError(
|
|
"You are attempting to pad samples but the tokenizer you are using"
|
|
f" ({tokenizer.__class__.__name__}) does not have a pad token."
|
|
)
|
|
|
|
# Creating the full tensor and filling it with our data.
|
|
max_length = max(x.size(0) for x in examples)
|
|
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
|
|
for i, example in enumerate(examples):
|
|
if tokenizer.padding_side == "right":
|
|
result[i, : example.shape[0]] = example
|
|
else:
|
|
result[i, -example.shape[0] :] = example
|
|
return result
|
|
|
|
|
|
def tolist(x: Union[List[Any], torch.Tensor]):
    """Return ``x.tolist()`` when ``x`` is a tensor; return ``x`` itself, unchanged, otherwise."""
    if isinstance(x, torch.Tensor):
        return x.tolist()
    return x
|
|
|
|
|
|
@dataclass
class DataCollatorForSeq2Seq:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
              single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
              lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
            the use of Tensor Cores on NVIDIA hardware.
        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100

    def __call__(self, features):
        """
        Pad ``labels`` (when present) to the longest label sequence of the batch, then pad everything else with
        ``tokenizer.pad`` and return PyTorch tensors.

        The caller's feature mappings are left untouched: padding is applied to shallow copies.
        """
        labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
        # same length to return tensors.
        if labels is not None:
            max_label_length = max(len(label) for label in labels)
            pad_on_right = self.tokenizer.padding_side == "right"
            padded_features = []
            for feature, label in zip(features, labels):
                remainder = [self.label_pad_token_id] * (max_label_length - len(label))
                # Copy instead of writing back into `feature`: mutating the caller's samples would make the
                # padding accumulate when the same objects are collated again (e.g. on the next epoch).
                padded = dict(feature)
                padded["labels"] = list(label) + remainder if pad_on_right else remainder + list(label)
                padded_features.append(padded)
            features = padded_features

        return self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
|
|
|
|
|
|
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
            inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
            non-masked tokens and the value to predict for the masked token.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.

    .. note::

        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15

    def __post_init__(self):
        """Reject the ``mlm=True`` configuration when the tokenizer has no mask token."""
        # Only look at the mask token when it is actually needed.
        if not self.mlm:
            return
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Batch ``examples`` and add a ``labels`` entry for masked (``mlm=True``) or causal language modeling."""
        # Dict-like samples go through the tokenizer; raw id sequences are padded by hand.
        if isinstance(examples[0], (dict, BatchEncoding)):
            batch = self.tokenizer.pad(examples, return_tensors="pt")
        else:
            batch = {"input_ids": _collate_batch(examples, self.tokenizer)}

        # A precomputed special tokens mask is consumed here and never forwarded in the batch.
        precomputed_mask = batch.pop("special_tokens_mask", None)

        if not self.mlm:
            # Labels are a copy of the inputs, with padding positions ignored by the loss.
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
            return batch

        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask=precomputed_mask
        )
        return batch

    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        ``inputs`` is modified in place and returned together with the labels, which hold the original id at the
        selected positions and -100 everywhere else.
        """
        targets = inputs.clone()

        # Positions flagged as special tokens are never selected for masking.
        if special_tokens_mask is not None:
            protected = special_tokens_mask.bool()
        else:
            per_row_flags = [
                self.tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
                for row in targets.tolist()
            ]
            protected = torch.tensor(per_row_flags, dtype=torch.bool)

        # Each remaining position is selected independently with probability `self.mlm_probability`.
        selection_probs = torch.full(targets.shape, self.mlm_probability)
        selection_probs.masked_fill_(protected, value=0.0)
        selected = torch.bernoulli(selection_probs).bool()
        targets[~selected] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        use_mask_token = torch.bernoulli(torch.full(targets.shape, 0.8)).bool() & selected
        inputs[use_mask_token] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time (half of the remaining 20%), we replace masked input tokens with a random word
        use_random_token = torch.bernoulli(torch.full(targets.shape, 0.5)).bool() & selected & ~use_mask_token
        replacements = torch.randint(len(self.tokenizer), targets.shape, dtype=torch.long)
        inputs[use_random_token] = replacements[use_random_token]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, targets
|
|
|
|
|
|
@dataclass
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
    """
    Data collator used for language modeling with whole word masking.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling, selecting all sub-tokens of a word together (sub-tokens
      are recognised by their ``##`` prefix)
    """

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """
        Batch ``examples`` and apply whole word masking.

        Each example is either a sequence of token ids or a dict-like object with an ``"input_ids"`` entry and,
        optionally, a ``"chinese_ref"`` entry listing the positions to treat as word continuations.

        Returns:
            A dict with the masked ``input_ids`` and the corresponding ``labels``.
        """
        if isinstance(examples[0], (dict, BatchEncoding)):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = _collate_batch(input_ids, self.tokenizer)

        mask_labels = []
        for e in examples:
            ref_tokens = [self.tokenizer._convert_id_to_token(token_id) for token_id in tolist(e["input_ids"])]

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
            if "chinese_ref" in e:
                # A set makes each membership test O(1) instead of scanning the reference list per position.
                ref_pos = set(tolist(e["chinese_ref"]))
                for i in range(len(e["input_ids"])):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))

        batch_mask = _collate_batch(mask_labels, self.tokenizer)
        inputs, labels = self.mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy.

        Tokens are grouped into words (a token starting with ``##`` joins the previous group; ``[CLS]`` and
        ``[SEP]`` are never candidates). Whole groups are then picked in random order until about
        ``len(input_tokens) * mlm_probability`` tokens (at least 1, at most ``max_predictions``) are selected.

        Returns:
            A list with one entry per input token: 1 if the token is selected for masking, 0 otherwise.
        """
        cand_indexes = []
        for i, token in enumerate(input_tokens):
            if token in ("[CLS]", "[SEP]"):
                continue

            if cand_indexes and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        random.shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            covered_indexes.update(index_set)
            masked_lms.extend(index_set)

        assert len(covered_indexes) == len(masked_lms)
        return [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]

    def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        The positions to mask are not sampled here: they are given by ``mask_labels`` (non-zero means "mask"), minus
        special tokens and padding. Both ``inputs`` and ``mask_labels`` are modified in place.

        Raises:
            ValueError: If the tokenizer has no mask token.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()

        # The whole-word selection replaces the usual per-token sampling with probability `mlm_probability`.
        probability_matrix = mask_labels

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            # Also clears the positions that were added when `mask_labels` was padded with the pad token id.
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
|
|
|
|
|
|
@dataclass
class DataCollatorForSOP(DataCollatorForLanguageModeling):
    """
    Data collator used for sentence order prediction task.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for both masked language modeling and sentence order prediction

    Deprecated: instantiating this class emits a :obj:`FutureWarning`.
    """

    def __init__(self, *args, **kwargs):
        """Warn about the deprecation, then initialise the fields through the parent initializer."""
        warnings.warn(
            "DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
            "DataCollatorForLanguageModeling instead.",
            FutureWarning,
        )
        # `@dataclass` keeps an `__init__` defined in the class body, so the fields (notably `tokenizer`) must be
        # set by forwarding to the parent explicitly. Argument-less construction is still tolerated for callers
        # that assign the attributes afterwards.
        if args or kwargs:
            super().__init__(*args, **kwargs)

    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Build a batch from examples providing ``input_ids``, ``token_type_ids`` and ``sentence_order_label``.

        Returns:
            A dict with the masked ``input_ids``, the MLM ``labels``, the ``attention_mask`` produced by
            :meth:`mask_tokens`, the padded ``token_type_ids`` and the stacked ``sentence_order_label``.
        """
        input_ids = [example["input_ids"] for example in examples]
        input_ids = _collate_batch(input_ids, self.tokenizer)
        input_ids, labels, attention_mask = self.mask_tokens(input_ids)

        token_type_ids = [example["token_type_ids"] for example in examples]
        # The size of segment ids varies because of randomness; they are padded at the end with the tokenizer's
        # pad token id. NOTE(review): the original comment says "padding zero", which only matches when that id
        # is 0 -- confirm.
        token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        sop_label_list = [example["sentence_order_label"] for example in examples]
        sentence_order_label = torch.stack(sop_label_list)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "sentence_order_label": sentence_order_label,
        }

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
        original. N-gram not applied yet.

        ``inputs`` is modified in place.

        Raises:
            ValueError: If the tokenizer has no mask token.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
            )

        labels = inputs.clone()
        # Each token is selected with probability `self.mlm_probability`; special tokens and padding never are.
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # Selected positions are `1` in `masked_indices`, whereas in the attention mask `0` means masked, so the
        # values are inverted.
        attention_mask = (~masked_indices).float()
        if self.tokenizer._pad_token is not None:
            # NOTE(review): this sets padding positions to 1.0; they are never selected above, so they already
            # hold 1.0 at this point -- confirm that attending to padding is intended.
            attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
            attention_mask.masked_fill_(attention_padding_mask, value=1.0)
        labels[~masked_indices] = -100  # We only compute loss on masked tokens, -100 is default for CE compute

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, attention_mask
|
|
|
|
|
|
@dataclass
class DataCollatorForPermutationLanguageModeling:
    """
    Data collator used for permutation language modeling.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for permutation language modeling with procedures specific to XLNet

    Args:
        tokenizer: The tokenizer used for encoding the data; it must define a mask token.
        plm_probability (:obj:`float`, defaults to 1/6): Ratio of span length to surrounding context length.
        max_span_length (:obj:`int`, defaults to 5): Maximum length of a span of masked tokens.
    """

    tokenizer: PreTrainedTokenizerBase
    plm_probability: float = 1 / 6
    max_span_length: int = 5  # maximum length of a span of masked tokens

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Batch ``examples`` (only their ``input_ids`` when dict-like) and build the PLM inputs and targets."""
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        batch = _collate_batch(examples, self.tokenizer)
        inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

            0. Start from the beginning of the sequence by setting ``cur_len = 0`` (number of tokens processed so far).
            1. Sample a ``span_length`` from the interval ``[1, max_span_length]`` (length of span of tokens to be
               masked)
            2. Reserve a context of length ``context_length = span_length / plm_probability`` to surround span to be
               masked
            3. Sample a starting point ``start_index`` from the interval ``[cur_len, cur_len + context_length -
               span_length]`` and mask tokens ``start_index:start_index + span_length``
            4. Set ``cur_len = cur_len + context_length``. If ``cur_len < max_len`` (i.e. there are tokens remaining in
               the sequence to be processed), repeat from Step 1.

        ``inputs`` is modified in place.

        Returns:
            A tuple ``(inputs, perm_mask, target_mapping, labels)``.

        Raises:
            ValueError: If the tokenizer has no mask token, or if the sequence length is odd.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
            )

        if inputs.size(1) % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
            )

        labels = inputs.clone()
        batch_size, max_len = labels.size(0), labels.size(1)
        # Creating the mask and target_mapping tensors
        masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
        target_mapping = torch.zeros((batch_size, max_len, max_len), dtype=torch.float32)
        # The same identity mapping is used for every sequence, so it is built once.
        identity = torch.eye(max_len)

        for i in range(batch_size):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th predict corresponds to the i-th token.
            target_mapping[i] = identity

        special_tokens_mask = torch.tensor(
            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
            dtype=torch.bool,
        )
        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            masked_indices.masked_fill_(padding_mask, value=0.0)
        else:
            # Without a pad token there is no padding to exclude; an all-False mask keeps `non_func_mask` below
            # well defined instead of referencing an unassigned variable.
            padding_mask = torch.zeros_like(special_tokens_mask)

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs[masked_indices] = self.tokenizer.mask_token_id
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        perm_mask = torch.zeros((batch_size, max_len, max_len), dtype=torch.float32)

        for i in range(batch_size):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            perm_index = torch.arange(max_len)
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = perm_index.reshape((-1, max_len // 2)).transpose(0, 1)
            # Permute the two halves such that they do not cross over
            perm_index = perm_index[torch.randperm(max_len // 2)]
            # Flatten this out into the desired permuted factorisation order
            perm_index = torch.flatten(perm_index.transpose(0, 1))
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
            # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask[i] = (
                perm_index.reshape((max_len, 1)) <= perm_index.reshape((1, max_len))
            ) & masked_indices[i]

        return inputs.long(), perm_mask, target_mapping, labels.long()
|