Improve pytorch examples for fp16 (#9796)

* Pad to 8x for fp16 multiple choice example (#9752)

* Pad to 8x for fp16 squad trainer example (#9752)

* Pad to 8x for fp16 ner example (#9752)

* Pad to 8x for fp16 swag example (#9752)

* Pad to 8x for fp16 qa beam search example (#9752)

* Pad to 8x for fp16 qa example (#9752)

* Pad to 8x for fp16 seq2seq example (#9752)

* Pad to 8x for fp16 glue example (#9752)

* Pad to 8x for fp16 new ner example (#9752)

* update script template #9752

* Update examples/multiple-choice/run_swag.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/question-answering/run_qa.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/question-answering/run_qa_beam_search.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* improve code quality #9752

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Andrea Cappelli
2021-01-26 10:47:07 +01:00
committed by GitHub
parent 781e4b1384
commit 10e5f28212
10 changed files with 53 additions and 9 deletions

View File

@@ -33,6 +33,7 @@ from transformers import (
AutoConfig,
{{cookiecutter.model_class}},
AutoTokenizer,
DataCollatorWithPadding,
HfArgumentParser,
Trainer,
TrainingArguments,
@@ -323,7 +324,7 @@ def main():
)
# Data collator
data_collator=default_data_collator
data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
# Initialize our Trainer
trainer = Trainer(