From 6c40e49712972141c3d7aeba4ed90bb79f2bb078 Mon Sep 17 00:00:00 2001 From: Andrea Cappelli Date: Thu, 8 Apr 2021 22:12:49 +0200 Subject: [PATCH] Run mlm pad to multiple for fp16 (#11128) * Add mlm collator pad to multiple option (#10627) * Use padding to 8x in run mlm (#10627) --- examples/language-modeling/run_mlm.py | 7 +++- src/transformers/data/data_collator.py | 13 ++++-- tests/test_data_collator.py | 56 ++++++++++++++++++++++++-- 3 files changed, 67 insertions(+), 9 deletions(-) diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index 4fd3c4f217..2934fb0c23 100755 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -422,7 +422,12 @@ def main(): # Data collator # This one will take care of randomly masking the tokens. - data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability) + pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm_probability=data_args.mlm_probability, + pad_to_multiple_of=8 if pad_to_multiple_of_8 else None, + ) # Initialize our Trainer trainer = Trainer( diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 94eaade7b1..9915eb5a5f 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -192,7 +192,7 @@ class DataCollatorForTokenClassification: return batch -def _collate_batch(examples, tokenizer): +def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" # Tensorize if necessary. if isinstance(examples[0], (list, tuple)): @@ -201,7 +201,7 @@ def _collate_batch(examples, tokenizer): # Check if padding is necessary. length_of_first = examples[0].size(0) are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) - if are_tensors_same_length: + if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): return torch.stack(examples, dim=0) # If yes, check if we have a `pad_token`. @@ -213,6 +213,8 @@ def _collate_batch(examples, tokenizer): # Creating the full tensor and filling it with our data. max_length = max(x.size(0) for x in examples) + if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) for i, example in enumerate(examples): if tokenizer.padding_side == "right": @@ -311,6 +313,8 @@ class DataCollatorForLanguageModeling: non-masked tokens and the value to predict for the masked token. mlm_probability (:obj:`float`, `optional`, defaults to 0.15): The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`. + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. .. note:: @@ -323,6 +327,7 @@ class DataCollatorForLanguageModeling: tokenizer: PreTrainedTokenizerBase mlm: bool = True mlm_probability: float = 0.15 + pad_to_multiple_of: Optional[int] = None def __post_init__(self): if self.mlm and self.tokenizer.mask_token is None: @@ -336,9 +341,9 @@ class DataCollatorForLanguageModeling: ) -> Dict[str, torch.Tensor]: # Handle dict or lists with proper padding and conversion to tensor. if isinstance(examples[0], (dict, BatchEncoding)): - batch = self.tokenizer.pad(examples, return_tensors="pt") + batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of) else: - batch = {"input_ids": _collate_batch(examples, self.tokenizer)} + batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)} # If special token mask has been preprocessed, pop it from the dict. special_tokens_mask = batch.pop("special_tokens_mask", None) diff --git a/tests/test_data_collator.py b/tests/test_data_collator.py index be138314d3..e9d363229f 100644 --- a/tests/test_data_collator.py +++ b/tests/test_data_collator.py @@ -146,11 +146,8 @@ class DataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["labels"].shape, torch.Size([2, 6])) self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3) - def test_data_collator_for_language_modeling(self): + def _test_no_pad_and_pad(self, no_pad_features, pad_features): tokenizer = BertTokenizer(self.vocab_file) - no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] - pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}] - data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) batch = data_collator(no_pad_features) self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10))) @@ -160,6 +157,15 @@ class DataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10))) self.assertEqual(batch["labels"].shape, torch.Size((2, 10))) + data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8) + batch = data_collator(no_pad_features) + self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16))) + self.assertEqual(batch["labels"].shape, torch.Size((2, 16))) + + batch = data_collator(pad_features) + self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16))) + self.assertEqual(batch["labels"].shape, torch.Size((2, 16))) + tokenizer._pad_token = None data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) with self.assertRaises(ValueError): @@ -185,6 +191,32 @@ class DataCollatorIntegrationTest(unittest.TestCase): self.assertTrue(torch.any(masked_tokens)) self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist())) + data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8) + batch = data_collator(no_pad_features) + self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16))) + self.assertEqual(batch["labels"].shape, torch.Size((2, 16))) + + masked_tokens = batch["input_ids"] == tokenizer.mask_token_id + self.assertTrue(torch.any(masked_tokens)) + self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist())) + + batch = data_collator(pad_features) + self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16))) + self.assertEqual(batch["labels"].shape, torch.Size((2, 16))) + + masked_tokens = batch["input_ids"] == tokenizer.mask_token_id + self.assertTrue(torch.any(masked_tokens)) + self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist())) + + def test_data_collator_for_language_modeling(self): + no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}] + self._test_no_pad_and_pad(no_pad_features, pad_features) + + no_pad_features = [list(range(10)), list(range(10))] + pad_features = [list(range(5)), list(range(10))] + self._test_no_pad_and_pad(no_pad_features, pad_features) + def test_plm(self): tokenizer = BertTokenizer(self.vocab_file) no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] @@ -225,6 +257,14 @@ class DataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["labels"].shape, torch.Size((2, 5))) self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,))) + data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8) + batch = data_collator(features) + + self.assertEqual(batch["input_ids"].shape, torch.Size((2, 8))) + self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 8))) + self.assertEqual(batch["labels"].shape, torch.Size((2, 8))) + self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,))) + def test_sop(self): tokenizer = BertTokenizer(self.vocab_file) features = [ @@ -242,3 +282,11 @@ class DataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5))) self.assertEqual(batch["labels"].shape, torch.Size((2, 5))) self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,))) + + data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8) + batch = data_collator(features) + + self.assertEqual(batch["input_ids"].shape, torch.Size((2, 8))) + self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 8))) + self.assertEqual(batch["labels"].shape, torch.Size((2, 8))) + self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))