Run mlm pad to multiple for fp16 (#11128)
* Add mlm collator pad to multiple option (#10627) * Use padding to 8x in run mlm (#10627)
This commit is contained in:
@@ -422,7 +422,12 @@ def main():
|
|||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
# This one will take care of randomly masking the tokens.
|
# This one will take care of randomly masking the tokens.
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
|
||||||
|
data_collator = DataCollatorForLanguageModeling(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
mlm_probability=data_args.mlm_probability,
|
||||||
|
pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ class DataCollatorForTokenClassification:
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
def _collate_batch(examples, tokenizer):
|
def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
|
||||||
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
||||||
# Tensorize if necessary.
|
# Tensorize if necessary.
|
||||||
if isinstance(examples[0], (list, tuple)):
|
if isinstance(examples[0], (list, tuple)):
|
||||||
@@ -201,7 +201,7 @@ def _collate_batch(examples, tokenizer):
|
|||||||
# Check if padding is necessary.
|
# Check if padding is necessary.
|
||||||
length_of_first = examples[0].size(0)
|
length_of_first = examples[0].size(0)
|
||||||
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
||||||
if are_tensors_same_length:
|
if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
|
||||||
return torch.stack(examples, dim=0)
|
return torch.stack(examples, dim=0)
|
||||||
|
|
||||||
# If yes, check if we have a `pad_token`.
|
# If yes, check if we have a `pad_token`.
|
||||||
@@ -213,6 +213,8 @@ def _collate_batch(examples, tokenizer):
|
|||||||
|
|
||||||
# Creating the full tensor and filling it with our data.
|
# Creating the full tensor and filling it with our data.
|
||||||
max_length = max(x.size(0) for x in examples)
|
max_length = max(x.size(0) for x in examples)
|
||||||
|
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||||
|
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||||
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
|
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
|
||||||
for i, example in enumerate(examples):
|
for i, example in enumerate(examples):
|
||||||
if tokenizer.padding_side == "right":
|
if tokenizer.padding_side == "right":
|
||||||
@@ -311,6 +313,8 @@ class DataCollatorForLanguageModeling:
|
|||||||
non-masked tokens and the value to predict for the masked token.
|
non-masked tokens and the value to predict for the masked token.
|
||||||
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
||||||
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
|
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
|
||||||
|
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||||
|
If set will pad the sequence to a multiple of the provided value.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
@@ -323,6 +327,7 @@ class DataCollatorForLanguageModeling:
|
|||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
mlm: bool = True
|
mlm: bool = True
|
||||||
mlm_probability: float = 0.15
|
mlm_probability: float = 0.15
|
||||||
|
pad_to_multiple_of: Optional[int] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.mlm and self.tokenizer.mask_token is None:
|
if self.mlm and self.tokenizer.mask_token is None:
|
||||||
@@ -336,9 +341,9 @@ class DataCollatorForLanguageModeling:
|
|||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
if isinstance(examples[0], (dict, BatchEncoding)):
|
if isinstance(examples[0], (dict, BatchEncoding)):
|
||||||
batch = self.tokenizer.pad(examples, return_tensors="pt")
|
batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
|
||||||
else:
|
else:
|
||||||
batch = {"input_ids": _collate_batch(examples, self.tokenizer)}
|
batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}
|
||||||
|
|
||||||
# If special token mask has been preprocessed, pop it from the dict.
|
# If special token mask has been preprocessed, pop it from the dict.
|
||||||
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
||||||
|
|||||||
@@ -146,11 +146,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
||||||
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
|
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
|
||||||
|
|
||||||
def test_data_collator_for_language_modeling(self):
|
def _test_no_pad_and_pad(self, no_pad_features, pad_features):
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
|
||||||
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
|
|
||||||
|
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
||||||
batch = data_collator(no_pad_features)
|
batch = data_collator(no_pad_features)
|
||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
||||||
@@ -160,6 +157,15 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
||||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
||||||
|
|
||||||
|
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8)
|
||||||
|
batch = data_collator(no_pad_features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
|
||||||
|
|
||||||
|
batch = data_collator(pad_features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
|
||||||
|
|
||||||
tokenizer._pad_token = None
|
tokenizer._pad_token = None
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@@ -185,6 +191,32 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(torch.any(masked_tokens))
|
self.assertTrue(torch.any(masked_tokens))
|
||||||
self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
|
self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
|
||||||
|
|
||||||
|
data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
|
||||||
|
batch = data_collator(no_pad_features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
|
||||||
|
|
||||||
|
masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
|
||||||
|
self.assertTrue(torch.any(masked_tokens))
|
||||||
|
self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
|
||||||
|
|
||||||
|
batch = data_collator(pad_features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
|
||||||
|
|
||||||
|
masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
|
||||||
|
self.assertTrue(torch.any(masked_tokens))
|
||||||
|
self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
|
||||||
|
|
||||||
|
def test_data_collator_for_language_modeling(self):
|
||||||
|
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
|
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
|
||||||
|
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||||
|
|
||||||
|
no_pad_features = [list(range(10)), list(range(10))]
|
||||||
|
pad_features = [list(range(5)), list(range(10))]
|
||||||
|
self._test_no_pad_and_pad(no_pad_features, pad_features)
|
||||||
|
|
||||||
def test_plm(self):
|
def test_plm(self):
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
||||||
@@ -225,6 +257,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
|
||||||
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
|
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
|
||||||
|
|
||||||
|
data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 8)))
|
||||||
|
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 8)))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 8)))
|
||||||
|
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
|
||||||
|
|
||||||
def test_sop(self):
|
def test_sop(self):
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
features = [
|
features = [
|
||||||
@@ -242,3 +282,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
|
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
|
||||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
|
||||||
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))
|
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))
|
||||||
|
|
||||||
|
data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
|
||||||
|
batch = data_collator(features)
|
||||||
|
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 8)))
|
||||||
|
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 8)))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 8)))
|
||||||
|
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))
|
||||||
|
|||||||
Reference in New Issue
Block a user