Improve pytorch examples for fp16 (#9796)
* Pad to 8x for fp16 multiple choice example (#9752) * Pad to 8x for fp16 squad trainer example (#9752) * Pad to 8x for fp16 ner example (#9752) * Pad to 8x for fp16 swag example (#9752) * Pad to 8x for fp16 qa beam search example (#9752) * Pad to 8x for fp16 qa example (#9752) * Pad to 8x for fp16 seq2seq example (#9752) * Pad to 8x for fp16 glue example (#9752) * Pad to 8x for fp16 new ner example (#9752) * update script template #9752 * Update examples/multiple-choice/run_swag.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/question-answering/run_qa.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/question-answering/run_qa_beam_search.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * improve code quality #9752 Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -28,6 +28,7 @@ from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
@@ -188,6 +189,9 @@ def main():
|
||||
preds = np.argmax(p.predictions, axis=1)
|
||||
return {"acc": simple_accuracy(preds, p.label_ids)}
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if training_args.fp16 else None
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
@@ -195,6 +199,7 @@ def main():
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
# Training
|
||||
|
||||
Reference in New Issue
Block a user