Enable torchdynamo with torch_tensorrt(fx path) (#17765)

* enable fx2trt

* Update perf_train_gpu_one.mdx

* Update perf_train_gpu_one.mdx

* add lib check

* update

* format

* update

* fix import check

* fix isort

* improve doc

* refactor ctx manager

* fix isort

* black format

* isort fix

* fix format

* update args

* update black

* cleanups

* Update perf_train_gpu_one.mdx

* code refactor

* code refactor to init

* remove redundancy

* isort

* replace self.args with args

Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Wei
2022-07-13 09:43:28 -07:00
committed by GitHub
parent 37aeb5787a
commit 7ea6ccc2b3
7 changed files with 88 additions and 22 deletions

View File

@@ -71,6 +71,7 @@ from .utils import (
is_torch_available,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_tensorrt_fx_available,
is_torch_tf32_available,
is_torch_tpu_available,
is_torchaudio_available,
@@ -499,6 +500,11 @@ def require_torchdynamo(test_case):
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
def require_torch_tensorrt_fx(test_case):
"""Decorator marking a test that requires Torch-TensorRT FX"""
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)