Extend Transformers Trainer Class to Enable CPU AMP and Integrate Intel Extension for PyTorch (#17138)
* init PR * fix import ipex * minor fix on bf16 * refine optimizer * refine args notes * refine code * refine ipex optimize args * refine half_precision_backend * black format * isort format * isort format files * flake8 format * doc builder format * refine codes * remove jit and optim bits * black preview format * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * refine code * refine notes * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * code refine * add ipex ut * add performance cpu doc * link to the cpu doc from main perf doc * install ipex into CI's docker * Update perf_train_cpu.mdx * Update docs/source/en/perf_train_cpu.mdx Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Update perf_train_cpu.mdx * Update perf_train_cpu.mdx Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Stas Bekman <stas@stason.org> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -47,6 +47,7 @@ from .utils import (
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
is_ftfy_available,
|
||||
is_ipex_available,
|
||||
is_librosa_available,
|
||||
is_onnx_available,
|
||||
is_pandas_available,
|
||||
@@ -282,6 +283,16 @@ def require_torch(test_case):
|
||||
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_intel_extension_for_pytorch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires Intel Extension for PyTorch.
|
||||
|
||||
These tests are skipped when Intel Extension for PyTorch isn't installed.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_torch_scatter(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch scatter.
|
||||
@@ -476,9 +487,10 @@ def require_torch_gpu(test_case):
|
||||
|
||||
|
||||
def require_torch_bf16(test_case):
|
||||
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
|
||||
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU."""
|
||||
return unittest.skipUnless(
|
||||
is_torch_bf16_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10"
|
||||
is_torch_bf16_available(),
|
||||
"test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU",
|
||||
)(test_case)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user