Enhance IPEX integration in Trainer (#18072)
* enhance ipex import * refine codes * refine style * add link * style Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
@@ -292,10 +292,15 @@ def require_intel_extension_for_pytorch(test_case):
|
|||||||
"""
|
"""
|
||||||
Decorator marking a test that requires Intel Extension for PyTorch.
|
Decorator marking a test that requires Intel Extension for PyTorch.
|
||||||
|
|
||||||
These tests are skipped when Intel Extension for PyTorch isn't installed.
|
These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
|
||||||
|
version.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case)
|
return unittest.skipUnless(
|
||||||
|
is_ipex_available(),
|
||||||
|
"test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
|
||||||
|
" https://github.com/intel/intel-extension-for-pytorch",
|
||||||
|
)(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_scatter(test_case):
|
def require_torch_scatter(test_case):
|
||||||
|
|||||||
@@ -1211,8 +1211,8 @@ class Trainer:
|
|||||||
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
|
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
|
||||||
if not is_ipex_available():
|
if not is_ipex_available():
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Using IPEX but IPEX is not installed, please refer to"
|
"Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
|
||||||
" https://github.com/intel/intel-extension-for-pytorch."
|
" to https://github.com/intel/intel-extension-for-pytorch."
|
||||||
)
|
)
|
||||||
|
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
@@ -1223,7 +1223,9 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
if not model.training:
|
if not model.training:
|
||||||
model.train()
|
model.train()
|
||||||
model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1")
|
model, self.optimizer = ipex.optimize(
|
||||||
|
model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@@ -443,7 +443,25 @@ def is_apex_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_ipex_available():
|
def is_ipex_available():
|
||||||
return importlib.util.find_spec("intel_extension_for_pytorch") is not None
|
def get_major_and_minor_from_version(full_version):
|
||||||
|
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
||||||
|
|
||||||
|
if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None:
|
||||||
|
return False
|
||||||
|
_ipex_version = "N/A"
|
||||||
|
try:
|
||||||
|
_ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
|
||||||
|
except importlib_metadata.PackageNotFoundError:
|
||||||
|
return False
|
||||||
|
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
|
||||||
|
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
|
||||||
|
if torch_major_and_minor != ipex_major_and_minor:
|
||||||
|
logger.warning(
|
||||||
|
f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
|
||||||
|
f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def is_bitsandbytes_available():
|
def is_bitsandbytes_available():
|
||||||
|
|||||||
@@ -642,7 +642,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
train_output = trainer.train()
|
train_output = trainer.train()
|
||||||
self.assertEqual(train_output.global_step, 10)
|
self.assertEqual(train_output.global_step, 10)
|
||||||
|
|
||||||
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
|
|
||||||
@require_torch_bf16_cpu
|
@require_torch_bf16_cpu
|
||||||
@require_intel_extension_for_pytorch
|
@require_intel_extension_for_pytorch
|
||||||
def test_number_of_steps_in_training_with_ipex(self):
|
def test_number_of_steps_in_training_with_ipex(self):
|
||||||
@@ -887,7 +886,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
|
|
||||||
@require_torch_bf16_cpu
|
@require_torch_bf16_cpu
|
||||||
@require_intel_extension_for_pytorch
|
@require_intel_extension_for_pytorch
|
||||||
def test_evaluate_with_ipex(self):
|
def test_evaluate_with_ipex(self):
|
||||||
@@ -1008,7 +1006,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||||
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||||
|
|
||||||
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
|
|
||||||
@require_torch_bf16_cpu
|
@require_torch_bf16_cpu
|
||||||
@require_intel_extension_for_pytorch
|
@require_intel_extension_for_pytorch
|
||||||
def test_predict_with_ipex(self):
|
def test_predict_with_ipex(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user