deprecate is_torch_bf16_available (#17738)
* deprecate is_torch_bf16_available * address suggestions
This commit is contained in:
@@ -67,7 +67,8 @@ from .utils import (
|
||||
is_timm_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_bf16_available,
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
@@ -486,11 +487,19 @@ def require_torch_gpu(test_case):
|
||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||
|
||||
|
||||
def require_torch_bf16(test_case):
|
||||
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU."""
|
||||
def require_torch_bf16_gpu(test_case):
|
||||
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
|
||||
return unittest.skipUnless(
|
||||
is_torch_bf16_available(),
|
||||
"test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU",
|
||||
is_torch_bf16_gpu_available(),
|
||||
"test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_bf16_cpu(test_case):
|
||||
"""Decorator marking a test that requires torch>=1.10, using CPU."""
|
||||
return unittest.skipUnless(
|
||||
is_torch_bf16_cpu_available(),
|
||||
"test requires torch>=1.10, using CPU",
|
||||
)(test_case)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user