Refine Bf16 test for deepspeed (#17734)
* Refine BF16 check in CPU/GPU * Fixes * Renames
This commit is contained in:
@@ -42,7 +42,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
)
|
||||
from transformers.trainer_utils import get_last_checkpoint, set_seed
|
||||
from transformers.utils import WEIGHTS_NAME, is_torch_bf16_available
|
||||
from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -129,7 +129,7 @@ FP16 = "fp16"
|
||||
BF16 = "bf16"
|
||||
|
||||
stages = [ZERO2, ZERO3]
|
||||
if is_torch_bf16_available():
|
||||
if is_torch_bf16_gpu_available():
|
||||
dtypes = [FP16, BF16]
|
||||
else:
|
||||
dtypes = [FP16]
|
||||
@@ -920,7 +920,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
||||
@require_torch_multi_gpu
|
||||
@parameterized.expand(["bf16", "fp16", "fp32"])
|
||||
def test_inference(self, dtype):
|
||||
if dtype == "bf16" and not is_torch_bf16_available():
|
||||
if dtype == "bf16" and not is_torch_bf16_gpu_available():
|
||||
self.skipTest("test requires bfloat16 hardware support")
|
||||
|
||||
# this is just inference, so no optimizer should be loaded
|
||||
|
||||
Reference in New Issue
Block a user