Improve pytorch examples for fp16 (#9796)
* Pad to 8x for fp16 multiple choice example (#9752) * Pad to 8x for fp16 squad trainer example (#9752) * Pad to 8x for fp16 ner example (#9752) * Pad to 8x for fp16 swag example (#9752) * Pad to 8x for fp16 qa beam search example (#9752) * Pad to 8x for fp16 qa example (#9752) * Pad to 8x for fp16 seq2seq example (#9752) * Pad to 8x for fp16 glue example (#9752) * Pad to 8x for fp16 new ner example (#9752) * update script template #9752 * Update examples/multiple-choice/run_swag.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/question-answering/run_qa.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/question-answering/run_qa_beam_search.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * improve code quality #9752 Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -437,7 +437,11 @@ def main():
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=label_pad_token_id)
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
# Metric
|
||||
metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
|
||||
|
||||
Reference in New Issue
Block a user