Enhance IPEX integration in Trainer (#18072)

* enhance ipex import

* refine codes

* refine style

* add link

* style

Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
jianan-gu
2022-07-12 12:34:09 +08:00
committed by GitHub
parent a462fc9232
commit b7d8bd378c
4 changed files with 31 additions and 9 deletions

View File

@@ -642,7 +642,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
@require_torch_bf16_cpu
@require_intel_extension_for_pytorch
def test_number_of_steps_in_training_with_ipex(self):
@@ -887,7 +886,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
@require_torch_bf16_cpu
@require_intel_extension_for_pytorch
def test_evaluate_with_ipex(self):
@@ -1008,7 +1006,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
@require_torch_bf16_cpu
@require_intel_extension_for_pytorch
def test_predict_with_ipex(self):