[trainer] add tf32-mode control (#14606)
* [trainer] add --tf32 support
* it's pt>=1.7
* it's pt>=1.7
* flip the default to True
* add experimental note
* simplify logic
* style
* switch to 3-state logic
* doc
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* re-style code

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -50,6 +50,7 @@ from .file_utils import (
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_bf16_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
is_vision_available,
|
||||
@@ -495,9 +496,17 @@ def require_torch_gpu(test_case):
|
||||
|
||||
|
||||
def require_torch_bf16(test_case):
    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10.

    The test is returned unchanged when `is_torch_bf16_available()` is truthy; otherwise it is wrapped with
    `unittest.skip` so that it is reported as skipped instead of being run.
    """
    if not is_torch_bf16_available():
        # Single skip reason, worded like the one used by `require_torch_tf32`.
        return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
    else:
        return test_case
|
||||
|
||||
|
||||
def require_torch_tf32(test_case):
    """Decorator for tests that need Ampere or a newer GPU arch, cuda>=11 and torch>=1.7.

    When `is_torch_tf32_available()` is truthy the test is handed back untouched; otherwise it is wrapped
    with `unittest.skip` carrying the reason below.
    """
    if is_torch_tf32_available():
        return test_case

    skip_reason = "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
    return unittest.skip(skip_reason)(test_case)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user