Improve pytorch examples for fp16 (#9796)
* Pad to 8x for fp16 multiple choice example (#9752) * Pad to 8x for fp16 squad trainer example (#9752) * Pad to 8x for fp16 ner example (#9752) * Pad to 8x for fp16 swag example (#9752) * Pad to 8x for fp16 qa beam search example (#9752) * Pad to 8x for fp16 qa example (#9752) * Pad to 8x for fp16 seq2seq example (#9752) * Pad to 8x for fp16 glue example (#9752) * Pad to 8x for fp16 new ner example (#9752) * update script template #9752 * Update examples/multiple-choice/run_swag.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/question-answering/run_qa.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/question-answering/run_qa_beam_search.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * improve code quality #9752 Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from transformers import (
|
||||
AutoConfig,
|
||||
{{cookiecutter.model_class}},
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
@@ -323,7 +324,7 @@ def main():
|
||||
)
|
||||
|
||||
# Data collator
|
||||
data_collator=default_data_collator
|
||||
data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
|
||||
Reference in New Issue
Block a user