Extend Transformers Trainer Class to Enable PyTorch Torchscript for Inference (#17153)
* add jit mode option and model wrap * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * refine code * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add ut and refine code * code refine * refine code * add inference doc * Update src/transformers/trainer.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * add cpu inference performance doc * Update perf_infer_cpu.mdx * Update perf_infer_cpu.mdx * Update performance.mdx * Update _toctree.yml * refine jit func naming * Update _toctree.yml * Delete perf_infer_gpu_one.mdx * Update perf_infer_cpu.mdx * Update docs/source/en/perf_infer_cpu.mdx Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * add none check before jit * Update docs/source/en/perf_infer_cpu.mdx Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update docs/source/en/perf_infer_cpu.mdx Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Stas Bekman <stas@stason.org> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -87,6 +87,8 @@
|
|||||||
title: Training on many GPUs
|
title: Training on many GPUs
|
||||||
- local: perf_train_cpu
|
- local: perf_train_cpu
|
||||||
title: Training on CPU
|
title: Training on CPU
|
||||||
|
- local: perf_infer_cpu
|
||||||
|
title: Inference on CPU
|
||||||
- local: perf_hardware
|
- local: perf_hardware
|
||||||
title: Custom hardware for training
|
title: Custom hardware for training
|
||||||
- local: testing
|
- local: testing
|
||||||
|
|||||||
57
docs/source/en/perf_infer_cpu.mdx
Normal file
57
docs/source/en/perf_infer_cpu.mdx
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
-->
|
||||||
|
|
||||||
|
# Efficient Inference on CPU
|
||||||
|
|
||||||
|
This guide focuses on inferencing large models efficiently on CPU.
|
||||||
|
|
||||||
|
## PyTorch JIT-mode (TorchScript)
|
||||||
|
TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.
|
||||||
|
Comparing to default eager mode, jit mode in PyTorch normally yields better performance for model inference from optimization methodologies like operator fusion.
|
||||||
|
|
||||||
|
For a gentle introduction to TorchScript, see the Introduction to [PyTorch TorchScript tutorial](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#tracing-modules).
|
||||||
|
|
||||||
|
### IPEX Graph Optimization with JIT-mode
|
||||||
|
Intel® Extension for PyTorch provides further optimizations in jit mode for Transformers series models. It is highly recommended for users to take advantage of Intel® Extension for PyTorch with jit mode. Some frequently used operator patterns from Transformers models are already supported in Intel® Extension for PyTorch with jit mode fusions. Those fusion patterns like Multi-head-attention fusion, Concat Linear, Linear+Add, Linear+Gelu, Add+LayerNorm fusion and etc. are enabled and perform well. The benefit of the fusion is delivered to users in a transparent fashion. According to the analysis, ~70% of most popular NLP tasks in question-answering, text-classification, and token-classification can get performance benefits with these fusion patterns for both Float32 precision and BFloat16 Mixed precision.
|
||||||
|
|
||||||
|
Check more detailed information for [IPEX Graph Optimization](https://intel.github.io/intel-extension-for-pytorch/1.11.200/tutorials/features/graph_optimization.html).
|
||||||
|
|
||||||
|
#### IPEX installation:
|
||||||
|
|
||||||
|
IPEX release is following PyTorch, check the approaches for [IPEX installation](https://intel.github.io/intel-extension-for-pytorch/).
|
||||||
|
|
||||||
|
### Usage of JIT-mode
|
||||||
|
To enable jit mode in Trainer, users should add `jit_mode_eval` in Trainer command arguments.
|
||||||
|
|
||||||
|
Take an example of the use cases on [Transformers question-answering](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering)
|
||||||
|
|
||||||
|
- Inference using jit mode on CPU:
|
||||||
|
<pre>python run_qa.py \
|
||||||
|
--model_name_or_path csarron/bert-base-uncased-squad-v1 \
|
||||||
|
--dataset_name squad \
|
||||||
|
--do_eval \
|
||||||
|
--max_seq_length 384 \
|
||||||
|
--doc_stride 128 \
|
||||||
|
--output_dir /tmp/ \
|
||||||
|
--no_cuda \
|
||||||
|
<b>--jit_mode_eval </b></pre>
|
||||||
|
|
||||||
|
- Inference with IPEX using jit mode on CPU:
|
||||||
|
<pre>python run_qa.py \
|
||||||
|
--model_name_or_path csarron/bert-base-uncased-squad-v1 \
|
||||||
|
--dataset_name squad \
|
||||||
|
--do_eval \
|
||||||
|
--max_seq_length 384 \
|
||||||
|
--doc_stride 128 \
|
||||||
|
--output_dir /tmp/ \
|
||||||
|
--no_cuda \
|
||||||
|
<b>--use_ipex \</b>
|
||||||
|
<b>--jit_mode_eval</b></pre>
|
||||||
@@ -58,7 +58,7 @@ Efficient inference with large models in a production environment can be as chal
|
|||||||
|
|
||||||
### CPU
|
### CPU
|
||||||
|
|
||||||
_Coming soon_
|
[Go to CPU inference section](perf_infer_cpu.mdx)
|
||||||
|
|
||||||
### Single GPU
|
### Single GPU
|
||||||
|
|
||||||
|
|||||||
@@ -1167,6 +1167,29 @@ class Trainer:
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def torch_jit_model_eval(self, model, dataloader, training=False):
|
||||||
|
if not training:
|
||||||
|
if dataloader is None:
|
||||||
|
logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
|
||||||
|
return model
|
||||||
|
jit_inputs = []
|
||||||
|
example_batch = next(iter(dataloader))
|
||||||
|
for key in example_batch:
|
||||||
|
example_tensor = torch.ones_like(example_batch[key])
|
||||||
|
jit_inputs.append(example_tensor)
|
||||||
|
jit_inputs = tuple(jit_inputs)
|
||||||
|
try:
|
||||||
|
jit_model = model.eval()
|
||||||
|
with ContextManagers([self.autocast_smart_context_manager(), torch.no_grad()]):
|
||||||
|
jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
|
||||||
|
jit_model = torch.jit.freeze(jit_model)
|
||||||
|
jit_model(**example_batch)
|
||||||
|
model = jit_model
|
||||||
|
except (RuntimeError, TypeError) as e:
|
||||||
|
logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
|
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
|
||||||
if not is_ipex_available():
|
if not is_ipex_available():
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -1186,11 +1209,14 @@ class Trainer:
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.use_ipex:
|
if self.args.use_ipex:
|
||||||
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
|
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
|
||||||
model = self.ipex_optimize_model(model, training, dtype=dtype)
|
model = self.ipex_optimize_model(model, training, dtype=dtype)
|
||||||
|
|
||||||
|
if self.args.jit_mode_eval:
|
||||||
|
model = self.torch_jit_model_eval(model, dataloader, training)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
# Wrapping the base model twice in a DistributedModel will raise an error.
|
# Wrapping the base model twice in a DistributedModel will raise an error.
|
||||||
if isinstance(self.model_wrapped, smp.model.DistributedModel):
|
if isinstance(self.model_wrapped, smp.model.DistributedModel):
|
||||||
@@ -2700,7 +2726,7 @@ class Trainer:
|
|||||||
self.model_wrapped = deepspeed_engine
|
self.model_wrapped = deepspeed_engine
|
||||||
self.deepspeed = deepspeed_engine
|
self.deepspeed = deepspeed_engine
|
||||||
|
|
||||||
model = self._wrap_model(self.model, training=False)
|
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
|
||||||
|
|
||||||
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
|
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
|
||||||
# while ``train`` is running, cast it to the right dtype first and then put on device
|
# while ``train`` is running, cast it to the right dtype first and then put on device
|
||||||
@@ -3261,7 +3287,7 @@ class Trainer:
|
|||||||
deepspeed_engine.optimizer.optimizer = None
|
deepspeed_engine.optimizer.optimizer = None
|
||||||
deepspeed_engine.lr_scheduler = None
|
deepspeed_engine.lr_scheduler = None
|
||||||
|
|
||||||
model = self._wrap_model(self.model, training=False)
|
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
|
||||||
|
|
||||||
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
|
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
|
||||||
# while ``train`` is running, cast it to the right dtype first and then put on device
|
# while ``train`` is running, cast it to the right dtype first and then put on device
|
||||||
|
|||||||
@@ -245,6 +245,8 @@ class TrainingArguments:
|
|||||||
Random seed to be used with data samplers. If not set, random generators for data sampling will use the
|
Random seed to be used with data samplers. If not set, random generators for data sampling will use the
|
||||||
same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
|
same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
|
||||||
seed.
|
seed.
|
||||||
|
jit_mode_eval (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to use PyTorch jit trace for inference.
|
||||||
use_ipex (`bool`, *optional*, defaults to `False`):
|
use_ipex (`bool`, *optional*, defaults to `False`):
|
||||||
Use Intel extension for PyTorch when it is available. [IPEX
|
Use Intel extension for PyTorch when it is available. [IPEX
|
||||||
installation](https://github.com/intel/intel-extension-for-pytorch).
|
installation](https://github.com/intel/intel-extension-for-pytorch).
|
||||||
@@ -625,6 +627,9 @@ class TrainingArguments:
|
|||||||
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
|
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
|
||||||
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||||
data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
|
data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
|
||||||
|
jit_mode_eval: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether or not to use PyTorch jit trace for inference"}
|
||||||
|
)
|
||||||
use_ipex: bool = field(
|
use_ipex: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
|
|||||||
@@ -844,6 +844,47 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
|
def test_evaluate_with_jit(self):
|
||||||
|
trainer = get_regression_trainer(a=1.5, b=2.5, compute_metrics=AlmostAccuracy(), jit_mode_eval=True)
|
||||||
|
results = trainer.evaluate()
|
||||||
|
|
||||||
|
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||||
|
pred = 1.5 * x + 2.5
|
||||||
|
expected_loss = ((pred - y) ** 2).mean()
|
||||||
|
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||||
|
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||||
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
|
# With a number of elements not a round multiple of the batch size
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5, b=2.5, eval_len=66, compute_metrics=AlmostAccuracy(), jit_mode_eval=True
|
||||||
|
)
|
||||||
|
results = trainer.evaluate()
|
||||||
|
|
||||||
|
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||||
|
pred = 1.5 * x + 2.5
|
||||||
|
expected_loss = ((pred - y) ** 2).mean()
|
||||||
|
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||||
|
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||||
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
|
# With logits preprocess
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5,
|
||||||
|
b=2.5,
|
||||||
|
compute_metrics=AlmostAccuracy(),
|
||||||
|
preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
|
||||||
|
jit_mode_eval=True,
|
||||||
|
)
|
||||||
|
results = trainer.evaluate()
|
||||||
|
|
||||||
|
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
|
||||||
|
pred = 1.5 * x + 2.5
|
||||||
|
expected_loss = ((pred - y) ** 2).mean()
|
||||||
|
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||||
|
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||||
|
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||||
|
|
||||||
@require_torch_bf16
|
@require_torch_bf16
|
||||||
@require_intel_extension_for_pytorch
|
@require_intel_extension_for_pytorch
|
||||||
def test_evaluate_with_ipex(self):
|
def test_evaluate_with_ipex(self):
|
||||||
@@ -930,6 +971,40 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||||
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||||
|
|
||||||
|
def test_predict_with_jit(self):
|
||||||
|
trainer = get_regression_trainer(a=1.5, b=2.5, jit_mode_eval=True)
|
||||||
|
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||||
|
x = trainer.eval_dataset.x
|
||||||
|
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||||
|
|
||||||
|
# With a number of elements not a round multiple of the batch size
|
||||||
|
trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, jit_mode_eval=True)
|
||||||
|
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||||
|
x = trainer.eval_dataset.x
|
||||||
|
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||||
|
|
||||||
|
# With more than one output of the model
|
||||||
|
trainer = get_regression_trainer(a=1.5, b=2.5, double_output=True, jit_mode_eval=True)
|
||||||
|
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||||
|
x = trainer.eval_dataset.x
|
||||||
|
self.assertEqual(len(preds), 2)
|
||||||
|
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
|
||||||
|
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
|
||||||
|
|
||||||
|
# With more than one output/label of the model
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
a=1.5, b=2.5, double_output=True, label_names=["labels", "labels_2"], jit_mode_eval=True
|
||||||
|
)
|
||||||
|
outputs = trainer.predict(trainer.eval_dataset)
|
||||||
|
preds = outputs.predictions
|
||||||
|
labels = outputs.label_ids
|
||||||
|
x = trainer.eval_dataset.x
|
||||||
|
self.assertEqual(len(preds), 2)
|
||||||
|
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
|
||||||
|
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
|
||||||
|
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||||
|
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||||
|
|
||||||
@require_torch_bf16
|
@require_torch_bf16
|
||||||
@require_intel_extension_for_pytorch
|
@require_intel_extension_for_pytorch
|
||||||
def test_predict_with_ipex(self):
|
def test_predict_with_ipex(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user