deprecate is_torch_bf16_available (#17738)

* deprecate is_torch_bf16_available

* address suggestions
This commit is contained in:
Stas Bekman
2022-06-20 05:40:11 -07:00
committed by GitHub
parent 132402d752
commit a2d34b7c04
5 changed files with 47 additions and 19 deletions

View File

@@ -306,7 +306,7 @@ stages = [ZERO2, ZERO3]
#
# dtypes = [FP16]
# so just hardcoding --fp16 for now
-# if is_torch_bf16_available():
+# if is_torch_bf16_gpu_available():
# dtypes += [BF16]

View File

@@ -57,7 +57,8 @@ from transformers.testing_utils import (
require_sigopt,
require_tokenizers,
require_torch,
-require_torch_bf16,
+require_torch_bf16_cpu,
+require_torch_bf16_gpu,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
@@ -554,7 +555,7 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
@require_torch_gpu
-@require_torch_bf16
+@require_torch_bf16_gpu
def test_mixed_bf16(self):
# very basic test
@@ -641,7 +642,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
-@require_torch_bf16
+@require_torch_bf16_cpu
@require_intel_extension_for_pytorch
def test_number_of_steps_in_training_with_ipex(self):
for mix_bf16 in [True, False]:
@@ -885,7 +886,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
-@require_torch_bf16
+@require_torch_bf16_cpu
@require_intel_extension_for_pytorch
def test_evaluate_with_ipex(self):
for mix_bf16 in [True, False]:
@@ -1005,7 +1006,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
-@require_torch_bf16
+@require_torch_bf16_cpu
@require_intel_extension_for_pytorch
def test_predict_with_ipex(self):
for mix_bf16 in [True, False]:
@@ -1888,7 +1889,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertGreater(orig_peak_mem, peak_mem * 2)
@require_torch_gpu
-@require_torch_bf16
+@require_torch_bf16_gpu
def test_bf16_full_eval(self):
# note: most of the logic is the same as test_fp16_full_eval