Extend Transformers Trainer Class to Enable CPU AMP and Integrate Intel Extension for PyTorch (#17138)

* init PR

* fix import ipex

* minor fix on bf16

* refine optimizer

* refine args notes

* refine code

* refine ipex optimize args

* refine half_precision_backend

* black format

* isort format

* isort format files

* flake8 format

* doc builder format

* refine codes

* remove jit and optim bits

* black preview format

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refine code

* refine notes

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* code refine

* add ipex ut

* add performance cpu doc

* link to the cpu doc from main perf doc

* install ipex into CI's docker

* Update perf_train_cpu.mdx

* Update docs/source/en/perf_train_cpu.mdx

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update perf_train_cpu.mdx

* Update perf_train_cpu.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
jianan-gu
2022-06-08 21:41:57 +08:00
committed by GitHub
parent ae7bae8fe7
commit 34097b3304
10 changed files with 336 additions and 42 deletions

View File

@@ -47,6 +47,7 @@ from .utils import (
is_faiss_available,
is_flax_available,
is_ftfy_available,
is_ipex_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,
@@ -282,6 +283,16 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
def require_intel_extension_for_pytorch(test_case):
"""
Decorator marking a test that requires Intel Extension for PyTorch.
These tests are skipped when Intel Extension for PyTorch isn't installed.
"""
return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case)
def require_torch_scatter(test_case):
"""
Decorator marking a test that requires PyTorch scatter.
@@ -476,9 +487,10 @@ def require_torch_gpu(test_case):
def require_torch_bf16(test_case):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU."""
return unittest.skipUnless(
is_torch_bf16_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10"
is_torch_bf16_available(),
"test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU",
)(test_case)