Refine Bf16 test for deepspeed (#17734)

* Refine BF16 check in CPU/GPU

* Fixes

* Renames
This commit is contained in:
Sylvain Gugger
2022-06-16 11:27:58 -04:00
committed by GitHub
parent f44e2c2b6f
commit 36d4647993
3 changed files with 29 additions and 15 deletions

View File

@@ -42,7 +42,7 @@ from transformers.testing_utils import (
slow,
)
from transformers.trainer_utils import get_last_checkpoint, set_seed
from transformers.utils import WEIGHTS_NAME, is_torch_bf16_available
from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available
if is_torch_available():
@@ -129,7 +129,7 @@ FP16 = "fp16"
BF16 = "bf16"
stages = [ZERO2, ZERO3]
if is_torch_bf16_available():
if is_torch_bf16_gpu_available():
dtypes = [FP16, BF16]
else:
dtypes = [FP16]
@@ -920,7 +920,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
@require_torch_multi_gpu
@parameterized.expand(["bf16", "fp16", "fp32"])
def test_inference(self, dtype):
if dtype == "bf16" and not is_torch_bf16_available():
if dtype == "bf16" and not is_torch_bf16_gpu_available():
self.skipTest("test requires bfloat16 hardware support")
# this is just inference, so no optimizer should be loaded