Added data collator for permutation (XLNet) language modeling and related calls (#5522)
* Added data collator for XLNet language modeling and related calls Added DataCollatorForXLNetLanguageModeling in data/data_collator.py to generate necessary inputs for language modeling training with XLNetLMHeadModel. Also added related arguments, logic and calls in examples/language-modeling/run_language_modeling.py. Resolves: #4739, #2008 (partially) * Changed name to `DataCollatorForPermutationLanguageModeling` Changed the name of `DataCollatorForXLNetLanguageModeling` to the more general `DataCollatorForPermutationLanguageModelling`. Removed the `--mlm` flag requirement for the new collator and defined a separate `--plm_probability` flag for its use. CTRL uses a CLM loss just like GPT and GPT-2, so should work out of the box with this script (provided `past` is taken care of similar to `mems` for XLNet). Changed calls and imports appropriately. * Added detailed comments, changed variable names Added more detailed comments to `DataCollatorForPermutationLanguageModeling` in `data/data_collator.py` to explain working. Also cleaned up variable names and made them more informative. * Added tests for new data collator Added tests in `tests/test_trainer.py` for DataCollatorForPermutationLanguageModeling based on those in DataCollatorForLanguageModeling. A specific test has been added to check for odd-length sequences. * Fixed styling issues
This commit is contained in:
@@ -21,8 +21,8 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
||||
Very simple data collator that:
|
||||
- simply collates batches of dict-like objects
|
||||
- Performs special handling for potential keys named:
|
||||
- `label`: handles a single value (int or float) per object
|
||||
- `label_ids`: handles a list of values per object
|
||||
- ``label``: handles a single value (int or float) per object
|
||||
- ``label_ids``: handles a list of values per object
|
||||
- does not do any additional preprocessing
|
||||
|
||||
i.e., Property names of the input object will be used as corresponding inputs to the model.
|
||||
@@ -134,3 +134,126 @@ class DataCollatorForLanguageModeling:
|
||||
|
||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||
return inputs, labels
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForPermutationLanguageModeling:
|
||||
"""
|
||||
Data collator used for permutation language modeling.
|
||||
- collates batches of tensors, honoring their tokenizer's pad_token
|
||||
- preprocesses batches for permutation language modeling with procedures specific to XLNet
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizer
|
||||
plm_probability: float = 1 / 6
|
||||
max_span_length: int = 5 # maximum length of a span of masked tokens
|
||||
|
||||
def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
batch = self._tensorize_batch(examples)
|
||||
inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
|
||||
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
||||
|
||||
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
|
||||
length_of_first = examples[0].size(0)
|
||||
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
||||
if are_tensors_same_length:
|
||||
return torch.stack(examples, dim=0)
|
||||
else:
|
||||
if self.tokenizer._pad_token is None:
|
||||
raise ValueError(
|
||||
"You are attempting to pad samples but the tokenizer you are using"
|
||||
f" ({self.tokenizer.__class__.__name__}) does not have one."
|
||||
)
|
||||
return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
||||
|
||||
def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
|
||||
0. Start from the beginning of the sequence by setting ``cur_len = 0`` (number of tokens processed so far).
|
||||
1. Sample a ``span_length`` from the interval ``[1, max_span_length]`` (length of span of tokens to be masked)
|
||||
2. Reserve a context of length ``context_length = span_length / plm_probability`` to surround span to be masked
|
||||
3. Sample a starting point ``start_index`` from the interval ``[cur_len, cur_len + context_length - span_length]`` and mask tokens ``start_index:start_index + span_length``
|
||||
4. Set ``cur_len = cur_len + context_length``. If ``cur_len < max_len`` (i.e. there are tokens remaining in the sequence to be processed), repeat from Step 1.
|
||||
"""
|
||||
|
||||
if self.tokenizer.mask_token is None:
|
||||
raise ValueError(
|
||||
"This tokenizer does not have a mask token which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
|
||||
)
|
||||
|
||||
if inputs.size(1) % 2 != 0:
|
||||
raise ValueError(
|
||||
"This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
|
||||
)
|
||||
|
||||
labels = inputs.clone()
|
||||
# Creating the mask and target_mapping tensors
|
||||
masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
|
||||
target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
|
||||
|
||||
for i in range(labels.size(0)):
|
||||
# Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
||||
cur_len = 0
|
||||
max_len = labels.size(1)
|
||||
|
||||
while cur_len < max_len:
|
||||
# Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
||||
span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
|
||||
# Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
|
||||
context_length = int(span_length / self.plm_probability)
|
||||
# Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
|
||||
start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
|
||||
masked_indices[i, start_index : start_index + span_length] = 1
|
||||
# Set `cur_len = cur_len + context_length`
|
||||
cur_len += context_length
|
||||
|
||||
# Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
|
||||
# the i-th predict corresponds to the i-th token.
|
||||
target_mapping[i] = torch.eye(labels.size(1))
|
||||
|
||||
special_tokens_mask = torch.tensor(
|
||||
[self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
|
||||
dtype=torch.bool,
|
||||
)
|
||||
masked_indices.masked_fill_(special_tokens_mask, value=0.0)
|
||||
if self.tokenizer._pad_token is not None:
|
||||
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
||||
masked_indices.masked_fill_(padding_mask, value=0.0)
|
||||
|
||||
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
||||
non_func_mask = ~(padding_mask & special_tokens_mask)
|
||||
|
||||
inputs[masked_indices] = self.tokenizer.mask_token_id
|
||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||
|
||||
perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
|
||||
|
||||
for i in range(labels.size(0)):
|
||||
# Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
|
||||
# determine which tokens a given token can attend to (encoded in `perm_mask`).
|
||||
# Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
|
||||
# (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
|
||||
# we assume that reused length is half of sequence length and permutation length is equal to reused length.
|
||||
# This requires that the sequence length be even.
|
||||
|
||||
# Create a linear factorisation order
|
||||
perm_index = torch.arange(labels.size(1))
|
||||
# Split this into two halves, assuming that half the sequence is reused each time
|
||||
perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
|
||||
# Permute the two halves such that they do not cross over
|
||||
perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
|
||||
# Flatten this out into the desired permuted factorisation order
|
||||
perm_index = torch.flatten(perm_index.transpose(0, 1))
|
||||
# Set the permutation indices of non-masked (non-functional) tokens to the
|
||||
# smallest index (-1) so that:
|
||||
# (1) They can be seen by all other positions
|
||||
# (2) They cannot see masked positions, so there won't be information leak
|
||||
perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
|
||||
# The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
|
||||
# 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
|
||||
# 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
|
||||
perm_mask[i] = (
|
||||
perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
|
||||
) & masked_indices[i]
|
||||
|
||||
return inputs, perm_mask, target_mapping, labels
|
||||
|
||||
Reference in New Issue
Block a user