Support compilation via Torchdynamo, AOT Autograd, NVFuser (#17308)
* Support compilation via Torchdynamo, AOT Autograd, NVFuser * Address comments * Lint * Stas comments - missing quality test * Lintere * Quality test * Doc lint * Reset CUDA peak mem * Add CustomTrainer * require a single gpu Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
@@ -70,6 +70,7 @@ from .utils import (
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
is_torchdynamo_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
@@ -464,6 +465,11 @@ else:
|
||||
jax_device = None
|
||||
|
||||
|
||||
def require_torchdynamo(test_case):
|
||||
"""Decorator marking a test that requires TorchDynamo"""
|
||||
return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
|
||||
|
||||
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch."""
|
||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||
|
||||
Reference in New Issue
Block a user