deprecate is_torch_bf16_available (#17738)

* deprecate is_torch_bf16_available

* address suggestions
This commit is contained in:
Stas Bekman
2022-06-20 05:40:11 -07:00
committed by GitHub
parent 132402d752
commit a2d34b7c04
5 changed files with 47 additions and 19 deletions

View File

@@ -67,7 +67,8 @@ from .utils import (
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_bf16_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
@@ -486,11 +487,19 @@ def require_torch_gpu(test_case):
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
def require_torch_bf16(test_case):
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU."""
def require_torch_bf16_gpu(test_case):
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
return unittest.skipUnless(
is_torch_bf16_available(),
"test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU",
is_torch_bf16_gpu_available(),
"test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
)(test_case)
def require_torch_bf16_cpu(test_case):
"""Decorator marking a test that requires torch>=1.10, using CPU."""
return unittest.skipUnless(
is_torch_bf16_cpu_available(),
"test requires torch>=1.10, using CPU",
)(test_case)