Repurpose torchdynamo training args towards torch._dynamo (#20498)
* Repurpose torchdynamo training args towards torch._dynamo * Add doc
This commit is contained in:
@@ -720,16 +720,25 @@ Another use case for training on many GPUs is if the model does not fit on a sin
|
|||||||
|
|
||||||
## Inference with torchdynamo
|
## Inference with torchdynamo
|
||||||
|
|
||||||
TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. One solution is using the [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as backend. You can choose one option below for performance boost.
|
TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. You can choose one option below for performance boost.
|
||||||
|
|
||||||
```
|
TorchDynamo has a growing list of backends, which can be found in [backends.py](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py)
|
||||||
TrainingArguments(torchdynamo="eager") #enable eager model GPU. No performance boost
|
or `torchdynamo.list_backends()`, each of which has its own optional dependencies.
|
||||||
TrainingArguments(torchdynamo="nvfuser") #enable nvfuser
|
|
||||||
TrainingArguments(torchdynamo="fx2trt") #enable tensorRT fp32
|
|
||||||
TrainingArguments(torchdynamo="fx2trt-f16") #enable tensorRT fp16
|
|
||||||
```
|
|
||||||
|
|
||||||
This feature involves 3 different libraries. To install them, please follow the instructions below:
|
Some of the most commonly used backends are
|
||||||
- [Torchdynamo installation](https://github.com/pytorch/torchdynamo#requirements-and-setup)
|
|
||||||
- [Functorch installation](https://github.com/pytorch/functorch#install)
|
**Debugging backends**:
|
||||||
- [Torch-TensorRT(FX) installation](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation)
|
* `dynamo.optimize("eager")` - Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo issues.
|
||||||
|
* `dynamo.optimize("aot_eager")` - Uses AotAutograd with no compiler, i.e., just using PyTorch eager for the AotAutograd's extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.
|
||||||
|
|
||||||
|
**Training & inference backends**:
|
||||||
|
* `dynamo.optimize("inductor")` - Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton kernels. [Read more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747)
|
||||||
|
* `dynamo.optimize("nvfuser")` - nvFuser with TorchScript. [Read more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
|
||||||
|
* `dynamo.optimize("aot_nvfuser")` - nvFuser with AotAutograd. [Read more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
|
||||||
|
* `dynamo.optimize("aot_cudagraphs")` - cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757)
|
||||||
|
|
||||||
|
**Inference-only backends**:
|
||||||
|
* `dynamo.optimize("ofi")` - Uses Torchscript optimize_for_inference. [Read more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html)
|
||||||
|
* `dynamo.optimize("fx2trt")` - Uses Nvidia TensorRT for inference optimizations. [Read more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst)
|
||||||
|
* `dynamo.optimize("onnxrt")` - Uses ONNXRT for inference on CPU/GPU. [Read more](https://onnxruntime.ai/)
|
||||||
|
* `dynamo.optimize("ipex")` - Uses IPEX for inference on CPU. [Read more](https://github.com/intel/intel-extension-for-pytorch)
|
||||||
|
|||||||
@@ -144,7 +144,6 @@ from .utils import (
|
|||||||
is_ipex_available,
|
is_ipex_available,
|
||||||
is_sagemaker_dp_enabled,
|
is_sagemaker_dp_enabled,
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
is_torch_tensorrt_fx_available,
|
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
is_torchdynamo_available,
|
is_torchdynamo_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -637,32 +636,8 @@ class Trainer:
|
|||||||
self._memory_tracker.stop_and_update_metrics()
|
self._memory_tracker.stop_and_update_metrics()
|
||||||
|
|
||||||
# torchdynamo
|
# torchdynamo
|
||||||
if args.torchdynamo:
|
if args.torchdynamo is not None and not is_torchdynamo_available():
|
||||||
if not is_torchdynamo_available():
|
raise RuntimeError("Using torchdynamo requires a nighly install of PyTorch.")
|
||||||
raise RuntimeError("Torchdynamo is not installed.")
|
|
||||||
import torchdynamo
|
|
||||||
from torchdynamo.optimizations import backends
|
|
||||||
|
|
||||||
def get_ctx():
|
|
||||||
# Normal
|
|
||||||
if args.torchdynamo == "eager":
|
|
||||||
return torchdynamo.optimize("eager")
|
|
||||||
elif args.torchdynamo == "nvfuser":
|
|
||||||
return torchdynamo.optimize("aot_nvfuser")
|
|
||||||
# TensorRT
|
|
||||||
if args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
|
|
||||||
if not is_torch_tensorrt_fx_available():
|
|
||||||
raise RuntimeError("Torch-TensorRT FX path is not installed.")
|
|
||||||
if args.torchdynamo == "fx2trt-fp16":
|
|
||||||
return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
|
|
||||||
elif args.torchdynamo == "fx2trt":
|
|
||||||
return torchdynamo.optimize(backends.fx2trt_compiler)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Torchdynamo backend {args.torchdynamo} is not supported.")
|
|
||||||
|
|
||||||
self.ctx_manager_torchdynamo = get_ctx()
|
|
||||||
else:
|
|
||||||
self.ctx_manager_torchdynamo = contextlib.nullcontext()
|
|
||||||
|
|
||||||
def add_callback(self, callback):
|
def add_callback(self, callback):
|
||||||
"""
|
"""
|
||||||
@@ -1339,6 +1314,10 @@ class Trainer:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
|
if self.args.torchdynamo is not None:
|
||||||
|
import torch._dynamo as dynamo
|
||||||
|
|
||||||
|
model = dynamo.optimize(self.args.torchdynamo)(model)
|
||||||
if self.args.use_ipex:
|
if self.args.use_ipex:
|
||||||
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
|
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
|
||||||
model = self.ipex_optimize_model(model, training, dtype=dtype)
|
model = self.ipex_optimize_model(model, training, dtype=dtype)
|
||||||
@@ -2494,18 +2473,7 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
A helper wrapper to group together context managers.
|
A helper wrapper to group together context managers.
|
||||||
"""
|
"""
|
||||||
return ContextManagers(
|
return self.autocast_smart_context_manager()
|
||||||
[
|
|
||||||
self.torchdynamo_smart_context_manager(),
|
|
||||||
self.autocast_smart_context_manager(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def torchdynamo_smart_context_manager(self):
|
|
||||||
"""
|
|
||||||
A helper wrapper that creates an appropriate context manager for `torchdynamo`.
|
|
||||||
"""
|
|
||||||
return self.ctx_manager_torchdynamo
|
|
||||||
|
|
||||||
def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
|
def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -73,6 +73,20 @@ log_levels = logging.get_log_levels_dict().copy()
|
|||||||
trainer_log_levels = dict(**log_levels, passive=-1)
|
trainer_log_levels = dict(**log_levels, passive=-1)
|
||||||
|
|
||||||
|
|
||||||
|
DYNAMO_BACKENDS = [
|
||||||
|
"eager",
|
||||||
|
"aot_eager",
|
||||||
|
"inductor",
|
||||||
|
"nvfuser",
|
||||||
|
"aot_nvfuser",
|
||||||
|
"aot_cudagraphs",
|
||||||
|
"ofi",
|
||||||
|
"fx2trt",
|
||||||
|
"onnxrt",
|
||||||
|
"ipex",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def default_logdir() -> str:
|
def default_logdir() -> str:
|
||||||
"""
|
"""
|
||||||
Same default as PyTorch
|
Same default as PyTorch
|
||||||
@@ -485,8 +499,8 @@ class TrainingArguments:
|
|||||||
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
|
If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
|
||||||
distributed training
|
distributed training
|
||||||
torchdynamo (`str`, *optional*):
|
torchdynamo (`str`, *optional*):
|
||||||
The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
|
If set, the backend compiler for TorchDynamo. Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`,
|
||||||
"nvfuser]. This is an experimental API and subject to change.
|
`"nvfuser"`, `"aot_nvfuser"`, `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
|
||||||
ray_scope (`str`, *optional*, defaults to `"last"`):
|
ray_scope (`str`, *optional*, defaults to `"last"`):
|
||||||
The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
|
The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
|
||||||
then use the last checkpoint of all trials, compare those, and select the best one. However, other options
|
then use the last checkpoint of all trials, compare those, and select the best one. However, other options
|
||||||
@@ -969,15 +983,8 @@ class TrainingArguments:
|
|||||||
torchdynamo: Optional[str] = field(
|
torchdynamo: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": (
|
"help": "Sets up the backend compiler for TorchDynamo.",
|
||||||
"Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
|
"choices": DYNAMO_BACKENDS,
|
||||||
" make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
|
|
||||||
" before its executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
|
|
||||||
" and lifts them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
|
|
||||||
" are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
|
|
||||||
" nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
|
|
||||||
),
|
|
||||||
"choices": ["eager", "nvfuser", "fx2trt", "fx2trt-fp16"],
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ray_scope: Optional[str] = field(
|
ray_scope: Optional[str] = field(
|
||||||
|
|||||||
@@ -445,7 +445,14 @@ def is_torch_tpu_available(check_device=True):
|
|||||||
|
|
||||||
|
|
||||||
def is_torchdynamo_available():
|
def is_torchdynamo_available():
|
||||||
return importlib.util.find_spec("torchdynamo") is not None
|
if not is_torch_available():
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
import torch._dynamo as dynamo # noqa: F401
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_torch_tensorrt_fx_available():
|
def is_torch_tensorrt_fx_available():
|
||||||
|
|||||||
@@ -1839,20 +1839,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
# 4. TorchDynamo fx2trt
|
# 4. TorchDynamo fx2trt
|
||||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
|
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
t1 = metrics["eval_loss"]
|
|
||||||
t2 = original_eval_loss
|
|
||||||
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
|
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
|
||||||
torchdynamo.reset()
|
torchdynamo.reset()
|
||||||
|
|
||||||
# 5. TorchDynamo fx2trt-fp16
|
|
||||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
|
|
||||||
metrics = trainer.evaluate()
|
|
||||||
t1 = metrics["eval_loss"]
|
|
||||||
t2 = original_eval_loss
|
|
||||||
# fp16 has accuracy degradation
|
|
||||||
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
|
|
||||||
torchdynamo.reset()
|
|
||||||
|
|
||||||
@require_torch_non_multi_gpu
|
@require_torch_non_multi_gpu
|
||||||
@require_torchdynamo
|
@require_torchdynamo
|
||||||
def test_torchdynamo_memory(self):
|
def test_torchdynamo_memory(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user