[trainer] add tf32-mode control (#14606)

* [trainer] add --tf32 support

* it's pt>=.17

* it's pt>=.17

* flip the default to True

* add experimental note

* simplify logic

* style

* switch to 3-state logic

* doc

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* re-style code

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-12-03 10:08:58 -08:00
committed by GitHub
parent aada989ad5
commit 71b1bf7ea8
5 changed files with 92 additions and 29 deletions

View File

@@ -57,6 +57,7 @@ from transformers.testing_utils import (
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
require_torch_tf32,
require_torch_up_to_2_gpus,
slow,
)
@@ -492,6 +493,15 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
# will add more specific tests once there are some bugs to fix
@require_torch_gpu
@require_torch_tf32
def test_tf32(self):
# very basic test
trainer = get_regression_trainer(learning_rate=0.1, tf32=True)
trainer.train()
self.check_trained_model(trainer.model)
@require_torch
@require_sentencepiece