From b7d8bd378caa170b54c3a07949d9f85b73c29333 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 12 Jul 2022 12:34:09 +0800 Subject: [PATCH] Enhance IPEX integration in Trainer (#18072) * enhance ipex import * refine codes * refine style * add link * style Co-authored-by: Stas Bekman --- src/transformers/testing_utils.py | 9 +++++++-- src/transformers/trainer.py | 8 +++++--- src/transformers/utils/import_utils.py | 20 +++++++++++++++++++- tests/trainer/test_trainer.py | 3 --- 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 1a71e9d840..cb2ef1b0fc 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -292,10 +292,15 @@ def require_intel_extension_for_pytorch(test_case): """ Decorator marking a test that requires Intel Extension for PyTorch. - These tests are skipped when Intel Extension for PyTorch isn't installed. + These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch + version. """ - return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case) + return unittest.skipUnless( + is_ipex_available(), + "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see" + " https://github.com/intel/intel-extension-for-pytorch", + )(test_case) def require_torch_scatter(test_case): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 0bebc8626b..f038f4ae5a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1211,8 +1211,8 @@ class Trainer: def ipex_optimize_model(self, model, training=False, dtype=torch.float32): if not is_ipex_available(): raise ImportError( - "Using IPEX but IPEX is not installed, please refer to" - " https://github.com/intel/intel-extension-for-pytorch." + "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer" + " to https://github.com/intel/intel-extension-for-pytorch." ) import intel_extension_for_pytorch as ipex @@ -1223,7 +1223,9 @@ class Trainer: else: if not model.training: model.train() - model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1") + model, self.optimizer = ipex.optimize( + model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1" + ) return model diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 31dbb536ac..f7c44ac6e3 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -443,7 +443,25 @@ def is_apex_available(): def is_ipex_available(): - return importlib.util.find_spec("intel_extension_for_pytorch") is not None + def get_major_and_minor_from_version(full_version): + return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) + + if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None: + return False + _ipex_version = "N/A" + try: + _ipex_version = importlib_metadata.version("intel_extension_for_pytorch") + except importlib_metadata.PackageNotFoundError: + return False + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) + ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) + if torch_major_and_minor != ipex_major_and_minor: + logger.warning( + f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," + f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." + ) + return False + return True def is_bitsandbytes_available(): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 73e7b4eeb1..15cb3cf2ce 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -642,7 +642,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): train_output = trainer.train() self.assertEqual(train_output.global_step, 10) - @unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12") @require_torch_bf16_cpu @require_intel_extension_for_pytorch def test_number_of_steps_in_training_with_ipex(self): @@ -887,7 +886,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"] self.assertAlmostEqual(results["eval_accuracy"], expected_acc) - @unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12") @require_torch_bf16_cpu @require_intel_extension_for_pytorch def test_evaluate_with_ipex(self): @@ -1008,7 +1006,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0])) self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1])) - @unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12") @require_torch_bf16_cpu @require_intel_extension_for_pytorch def test_predict_with_ipex(self):