Run mlm pad to multiple for fp16 (#11128)

* Add mlm collator pad to multiple option (#10627)

* Use padding to 8x in run mlm (#10627)
This commit is contained in:
Andrea Cappelli
2021-04-08 22:12:49 +02:00
committed by GitHub
parent dfed4ec263
commit 6c40e49712
3 changed files with 67 additions and 9 deletions

View File

@@ -422,7 +422,12 @@ def main():
# Data collator
# This one will take care of randomly masking the tokens.
-data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
+pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
+data_collator = DataCollatorForLanguageModeling(
+tokenizer=tokenizer,
+mlm_probability=data_args.mlm_probability,
+pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
+)
# Initialize our Trainer
trainer = Trainer(