[trainer] add tf32-mode control (#14606)
* [trainer] add --tf32 support * it's pt>=.17 * it's pt>=.17 * flip the default to True * add experimental note * simplify logic * style * switch to 3-state logic * doc * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * re-style code Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -57,6 +57,7 @@ from transformers.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_gpus,
|
||||
slow,
|
||||
)
|
||||
@@ -492,6 +493,15 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
# will add more specific tests once there are some bugs to fix
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_tf32
|
||||
def test_tf32(self):
|
||||
|
||||
# very basic test
|
||||
trainer = get_regression_trainer(learning_rate=0.1, tf32=True)
|
||||
trainer.train()
|
||||
self.check_trained_model(trainer.model)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
||||
Reference in New Issue
Block a user