Added data collator for permutation (XLNet) language modeling and related calls (#5522)
* Added data collator for XLNet language modeling and related calls Added DataCollatorForXLNetLanguageModeling in data/data_collator.py to generate necessary inputs for language modeling training with XLNetLMHeadModel. Also added related arguments, logic and calls in examples/language-modeling/run_language_modeling.py. Resolves: #4739, #2008 (partially) * Changed name to `DataCollatorForPermutationLanguageModeling` Changed the name of `DataCollatorForXLNetLanguageModeling` to the more general `DataCollatorForPermutationLanguageModeling`. Removed the `--mlm` flag requirement for the new collator and defined a separate `--plm_probability` flag for its use. CTRL uses a CLM loss just like GPT and GPT-2, so it should work out of the box with this script (provided `past` is taken care of similar to `mems` for XLNet). Changed calls and imports appropriately. * Added detailed comments, changed variable names Added more detailed comments to `DataCollatorForPermutationLanguageModeling` in `data/data_collator.py` to explain how it works. Also cleaned up variable names and made them more informative. * Added tests for new data collator Added tests in `tests/test_trainer.py` for DataCollatorForPermutationLanguageModeling based on those in DataCollatorForLanguageModeling. A specific test has been added to check for odd-length sequences. * Fixed styling issues
This commit is contained in:
@@ -14,9 +14,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
|
||||||
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
|
||||||
using a masked language modeling (MLM) loss.
|
using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -33,6 +33,7 @@ from transformers import (
|
|||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
|
DataCollatorForPermutationLanguageModeling,
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
LineByLineTextDataset,
|
LineByLineTextDataset,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
@@ -101,6 +102,15 @@ class DataTrainingArguments:
|
|||||||
mlm_probability: float = field(
|
mlm_probability: float = field(
|
||||||
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
||||||
)
|
)
|
||||||
|
plm_probability: float = field(
|
||||||
|
default=1 / 6,
|
||||||
|
metadata={
|
||||||
|
"help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_span_length: int = field(
|
||||||
|
default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
|
||||||
|
)
|
||||||
|
|
||||||
block_size: int = field(
|
block_size: int = field(
|
||||||
default=-1,
|
default=-1,
|
||||||
@@ -207,8 +217,8 @@ def main():
|
|||||||
|
|
||||||
if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
|
if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
|
"BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the"
|
||||||
"flag (masked language modeling)."
|
"--mlm flag (masked language modeling)."
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.block_size <= 0:
|
if data_args.block_size <= 0:
|
||||||
@@ -221,9 +231,14 @@ def main():
|
|||||||
|
|
||||||
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
||||||
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
|
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
|
||||||
data_collator = DataCollatorForLanguageModeling(
|
if config.model_type == "xlnet":
|
||||||
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
|
data_collator = DataCollatorForPermutationLanguageModeling(
|
||||||
)
|
tokenizer=tokenizer, plm_probability=data_args.plm_probability, max_span_length=data_args.max_span_length,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data_collator = DataCollatorForLanguageModeling(
|
||||||
|
tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|||||||
@@ -400,7 +400,12 @@ if is_torch_available():
|
|||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
from .trainer import Trainer, torch_distributed_zero_first
|
from .trainer import Trainer, torch_distributed_zero_first
|
||||||
from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
|
from .data.data_collator import (
|
||||||
|
default_data_collator,
|
||||||
|
DataCollator,
|
||||||
|
DataCollatorForLanguageModeling,
|
||||||
|
DataCollatorForPermutationLanguageModeling,
|
||||||
|
)
|
||||||
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
|
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
|
||||||
|
|
||||||
# Benchmarks
|
# Benchmarks
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
|
|||||||
Very simple data collator that:
|
Very simple data collator that:
|
||||||
- simply collates batches of dict-like objects
|
- simply collates batches of dict-like objects
|
||||||
- Performs special handling for potential keys named:
|
- Performs special handling for potential keys named:
|
||||||
- `label`: handles a single value (int or float) per object
|
- ``label``: handles a single value (int or float) per object
|
||||||
- `label_ids`: handles a list of values per object
|
- ``label_ids``: handles a list of values per object
|
||||||
- does not do any additional preprocessing
|
- does not do any additional preprocessing
|
||||||
|
|
||||||
i.e., Property names of the input object will be used as corresponding inputs to the model.
|
i.e., Property names of the input object will be used as corresponding inputs to the model.
|
||||||
@@ -134,3 +134,126 @@ class DataCollatorForLanguageModeling:
|
|||||||
|
|
||||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||||
return inputs, labels
|
return inputs, labels
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DataCollatorForPermutationLanguageModeling:
    """
    Data collator used for permutation language modeling.
    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for permutation language modeling with procedures specific to XLNet

    Calling the collator on a list of 1-D token-id tensors returns a dict with the keys
    ``input_ids``, ``perm_mask``, ``target_mapping`` and ``labels``.
    """

    # String annotation: the tokenizer class is only needed for type checking here.
    tokenizer: "PreTrainedTokenizer"
    plm_probability: float = 1 / 6
    max_span_length: int = 5  # maximum length of a span of masked tokens

    def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Stack/pad ``examples`` into one batch and build the permutation LM inputs for it."""
        batch = self._tensorize_batch(examples)
        inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
        """
        Combine ``examples`` into a single ``(batch, seq_len)`` tensor.

        Equal-length examples are stacked; otherwise they are right-padded with the tokenizer's
        pad token id.

        Raises:
            ValueError: if padding is required but the tokenizer has no pad token.
        """
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)

        if self.tokenizer._pad_token is None:
            raise ValueError(
                "You are attempting to pad samples but the tokenizer you are using"
                f" ({self.tokenizer.__class__.__name__}) does not have one."
            )
        return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
            0. Start from the beginning of the sequence by setting ``cur_len = 0`` (number of tokens processed so far).
            1. Sample a ``span_length`` from the interval ``[1, max_span_length]`` (length of span of tokens to be masked)
            2. Reserve a context of length ``context_length = span_length / plm_probability`` to surround span to be masked
            3. Sample a starting point ``start_index`` from the interval ``[cur_len, cur_len + context_length - span_length]`` and mask tokens ``start_index:start_index + span_length``
            4. Set ``cur_len = cur_len + context_length``. If ``cur_len < max_len`` (i.e. there are tokens remaining in the sequence to be processed), repeat from Step 1.

        Args:
            inputs: ``(batch, seq_len)`` tensor of token ids. It is modified in place: masked
                positions are overwritten with the tokenizer's mask token id.

        Returns:
            ``(inputs, perm_mask, target_mapping, labels)`` where ``perm_mask`` and
            ``target_mapping`` are float32 tensors of shape ``(batch, seq_len, seq_len)`` and
            ``labels`` holds the original ids at masked positions and ``-100`` elsewhere.

        Raises:
            ValueError: if the tokenizer has no mask token, or if ``seq_len`` is odd.
        """

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
            )

        if inputs.size(1) % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
            )

        labels = inputs.clone()
        batch_size, max_len = labels.size(0), labels.size(1)

        # Creating the mask and target_mapping tensors
        masked_indices = torch.zeros(labels.shape, dtype=torch.bool)

        for i in range(batch_size):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
                masked_indices[i, start_index : start_index + span_length] = True
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

        # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
        # the i-th predict corresponds to the i-th token, i.e. every sequence gets an identity mapping.
        target_mapping = torch.eye(max_len, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1, 1)

        special_tokens_mask = torch.tensor(
            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
            dtype=torch.bool,
        )
        # Special tokens are never selected for prediction.
        masked_indices.masked_fill_(special_tokens_mask, value=False)

        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            # Padding is never selected for prediction either.
            masked_indices.masked_fill_(padding_mask, value=False)
        else:
            # No pad token means no position can be padding. Binding an all-False mask keeps
            # the expression below valid (previously this path hit an unbound `padding_mask`).
            padding_mask = torch.zeros_like(masked_indices)

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        # A token is functional if it is padding OR special, hence the union.
        non_func_mask = ~(padding_mask | special_tokens_mask)

        inputs[masked_indices] = self.tokenizer.mask_token_id
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        perm_mask = torch.zeros((batch_size, max_len, max_len), dtype=torch.float32)

        for i in range(batch_size):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            perm_index = torch.arange(max_len)
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = perm_index.reshape((-1, max_len // 2)).transpose(0, 1)
            # Permute the two halves such that they do not cross over
            perm_index = perm_index[torch.randperm(max_len // 2)]
            # Flatten this out into the desired permuted factorisation order
            perm_index = torch.flatten(perm_index.transpose(0, 1))
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
            # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask[i] = (
                perm_index.reshape((max_len, 1)) <= perm_index.reshape((1, max_len))
            ) & masked_indices[i]

        return inputs, perm_mask, target_mapping, labels
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ if is_torch_available():
|
|||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
default_data_collator,
|
default_data_collator,
|
||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
|
DataCollatorForPermutationLanguageModeling,
|
||||||
GlueDataset,
|
GlueDataset,
|
||||||
GlueDataTrainingArguments,
|
GlueDataTrainingArguments,
|
||||||
TextDataset,
|
TextDataset,
|
||||||
@@ -123,6 +124,34 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||||
|
|
||||||
|
def test_plm(self):
    """Check the permutation LM collator's output type and shapes, and its odd-length rejection."""
    tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
    # Permutation LM collator built with its default settings.
    data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)

    def check_batch(batch, rows, cols):
        # Every output is batch-first; the two mask-like outputs are square over the sequence axis.
        self.assertIsInstance(batch, dict)
        self.assertEqual(batch["input_ids"].shape, torch.Size((rows, cols)))
        self.assertEqual(batch["perm_mask"].shape, torch.Size((rows, cols, cols)))
        self.assertEqual(batch["target_mapping"].shape, torch.Size((rows, cols, cols)))
        self.assertEqual(batch["labels"].shape, torch.Size((rows, cols)))

    # One example per line of the sample text.
    dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
    line_examples = [dataset[i] for i in range(len(dataset))]
    check_batch(data_collator(line_examples), 31, 112)

    # Fixed-size blocks cut from the same sample text.
    dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
    block_examples = [dataset[i] for i in range(len(dataset))]
    check_batch(data_collator(block_examples), 2, 512)

    odd_length_example = [torch.randint(5, [5])]
    # Expect error due to odd sequence length
    with self.assertRaises(ValueError):
        data_collator(odd_length_example)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TrainerIntegrationTest(unittest.TestCase):
|
class TrainerIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user