[trainer] add tf32-mode control (#14606)

* [trainer] add --tf32 support

* it's pt>=.17

* it's pt>=.17

* flip the default to True

* add experimental note

* simplify logic

* style

* switch to 3-state logic

* doc

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* re-style code

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-12-03 10:08:58 -08:00
committed by GitHub
parent aada989ad5
commit 71b1bf7ea8
5 changed files with 92 additions and 29 deletions

View File

@@ -50,6 +50,7 @@ from .file_utils import (
is_tokenizers_available,
is_torch_available,
is_torch_bf16_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
is_vision_available,
@@ -495,9 +496,17 @@ def require_torch_gpu(test_case):
def require_torch_bf16(test_case):
"""Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10."""
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
if not is_torch_bf16_available():
return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case)
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
else:
return test_case
def require_torch_tf32(test_case):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
if not is_torch_tf32_available():
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case)
else:
return test_case