Enable torchdynamo with torch_tensorrt(fx path) (#17765)
* enable fx2trt * Update perf_train_gpu_one.mdx * Update perf_train_gpu_one.mdx * add lib check * update * format * update * fix import check * fix isort * improve doc * refactor ctx manager * fix isort * black format * isort fix * fix format * update args * update black * cleanups * Update perf_train_gpu_one.mdx * code refactor * code refactor to init * remove redundancy * isort * replace self.args with args Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
@@ -71,6 +71,7 @@ from .utils import (
|
||||
is_torch_available,
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_tensorrt_fx_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
@@ -499,6 +500,11 @@ def require_torchdynamo(test_case):
|
||||
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
|
||||
|
||||
|
||||
def require_torch_tensorrt_fx(test_case):
|
||||
"""Decorator marking a test that requires Torch-TensorRT FX"""
|
||||
return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||
|
||||
Reference in New Issue
Block a user