deprecate is_torch_bf16_available (#17738)
* deprecate is_torch_bf16_available * address suggestions
This commit is contained in:
@@ -67,7 +67,8 @@ from .utils import (
|
|||||||
is_timm_available,
|
is_timm_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_bf16_available,
|
is_torch_bf16_cpu_available,
|
||||||
|
is_torch_bf16_gpu_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
@@ -486,11 +487,19 @@ def require_torch_gpu(test_case):
|
|||||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_bf16(test_case):
|
def require_torch_bf16_gpu(test_case):
|
||||||
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU."""
|
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
|
||||||
return unittest.skipUnless(
|
return unittest.skipUnless(
|
||||||
is_torch_bf16_available(),
|
is_torch_bf16_gpu_available(),
|
||||||
"test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU",
|
"test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
|
||||||
|
)(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_bf16_cpu(test_case):
|
||||||
|
"""Decorator marking a test that requires torch>=1.10, using CPU."""
|
||||||
|
return unittest.skipUnless(
|
||||||
|
is_torch_bf16_cpu_available(),
|
||||||
|
"test requires torch>=1.10, using CPU",
|
||||||
)(test_case)
|
)(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,8 @@ from .utils import (
|
|||||||
is_sagemaker_dp_enabled,
|
is_sagemaker_dp_enabled,
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_bf16_available,
|
is_torch_bf16_cpu_available,
|
||||||
|
is_torch_bf16_gpu_available,
|
||||||
is_torch_tf32_available,
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -1036,14 +1037,23 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
self.half_precision_backend = self.fp16_backend
|
self.half_precision_backend = self.fp16_backend
|
||||||
|
|
||||||
if (self.bf16 or self.bf16_full_eval) and not is_torch_bf16_available() and not self.no_cuda:
|
if self.bf16 or self.bf16_full_eval:
|
||||||
|
|
||||||
|
if self.no_cuda and not is_torch_bf16_cpu_available():
|
||||||
|
# cpu
|
||||||
|
raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
|
||||||
|
elif not is_torch_bf16_gpu_available():
|
||||||
|
# gpu
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Your setup doesn't support bf16. You need torch>=1.10, using Ampere GPU with cuda>=11.0 or using CPU"
|
"Your setup doesn't support bf16/gpu. You need torch>=1.10, using Ampere GPU with cuda>=11.0"
|
||||||
" (no_cuda)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.fp16 and self.bf16:
|
if self.fp16 and self.bf16:
|
||||||
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
|
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
|
||||||
|
|
||||||
|
if self.fp16_full_eval and self.bf16_full_eval:
|
||||||
|
raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")
|
||||||
|
|
||||||
if self.bf16:
|
if self.bf16:
|
||||||
if self.half_precision_backend == "apex":
|
if self.half_precision_backend == "apex":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import importlib.util
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
@@ -323,7 +324,14 @@ def is_torch_bf16_cpu_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_torch_bf16_available():
|
def is_torch_bf16_available():
|
||||||
return is_torch_bf16_cpu_available() or is_torch_bf16_gpu_available()
|
# the original bf16 check was for gpu only, but later a cpu/bf16 combo has emerged so this util
|
||||||
|
# has become ambiguous and therefore deprecated
|
||||||
|
warnings.warn(
|
||||||
|
"The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
|
||||||
|
"or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
return is_torch_bf16_gpu_available()
|
||||||
|
|
||||||
|
|
||||||
def is_torch_tf32_available():
|
def is_torch_tf32_available():
|
||||||
|
|||||||
@@ -306,7 +306,7 @@ stages = [ZERO2, ZERO3]
|
|||||||
#
|
#
|
||||||
# dtypes = [FP16]
|
# dtypes = [FP16]
|
||||||
# so just hardcoding --fp16 for now
|
# so just hardcoding --fp16 for now
|
||||||
# if is_torch_bf16_available():
|
# if is_torch_bf16_gpu_available():
|
||||||
# dtypes += [BF16]
|
# dtypes += [BF16]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,8 @@ from transformers.testing_utils import (
|
|||||||
require_sigopt,
|
require_sigopt,
|
||||||
require_tokenizers,
|
require_tokenizers,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_bf16,
|
require_torch_bf16_cpu,
|
||||||
|
require_torch_bf16_gpu,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
require_torch_non_multi_gpu,
|
require_torch_non_multi_gpu,
|
||||||
@@ -554,7 +555,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
|
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_torch_bf16
|
@require_torch_bf16_gpu
|
||||||
def test_mixed_bf16(self):
|
def test_mixed_bf16(self):
|
||||||
|
|
||||||
# very basic test
|
# very basic test
|
||||||
@@ -641,7 +642,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
train_output = trainer.train()
|
train_output = trainer.train()
|
||||||
self.assertEqual(train_output.global_step, 10)
|
self.assertEqual(train_output.global_step, 10)
|
||||||
|
|
||||||
@require_torch_bf16
|
@require_torch_bf16_cpu
|
||||||
@require_intel_extension_for_pytorch
|
@require_intel_extension_for_pytorch
|
||||||
def test_number_of_steps_in_training_with_ipex(self):
|
def test_number_of_steps_in_training_with_ipex(self):
|
||||||
for mix_bf16 in [True, False]:
|
for mix_bf16 in [True, False]:
|
||||||
@@ -885,7 +886,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
@require_torch_bf16
|
@require_torch_bf16_cpu
|
||||||
@require_intel_extension_for_pytorch
|
@require_intel_extension_for_pytorch
|
||||||
def test_evaluate_with_ipex(self):
|
def test_evaluate_with_ipex(self):
|
||||||
for mix_bf16 in [True, False]:
|
for mix_bf16 in [True, False]:
|
||||||
@@ -1005,7 +1006,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||||
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||||
|
|
||||||
@require_torch_bf16
|
@require_torch_bf16_cpu
|
||||||
@require_intel_extension_for_pytorch
|
@require_intel_extension_for_pytorch
|
||||||
def test_predict_with_ipex(self):
|
def test_predict_with_ipex(self):
|
||||||
for mix_bf16 in [True, False]:
|
for mix_bf16 in [True, False]:
|
||||||
@@ -1888,7 +1889,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertGreater(orig_peak_mem, peak_mem * 2)
|
self.assertGreater(orig_peak_mem, peak_mem * 2)
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_torch_bf16
|
@require_torch_bf16_gpu
|
||||||
def test_bf16_full_eval(self):
|
def test_bf16_full_eval(self):
|
||||||
# note: most of the logic is the same as test_fp16_full_eval
|
# note: most of the logic is the same as test_fp16_full_eval
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user