Enable torchdynamo with torch_tensorrt(fx path) (#17765)
* enable fx2trt * Update perf_train_gpu_one.mdx * Update perf_train_gpu_one.mdx * add lib check * update * format * update * fix import check * fix isort * improve doc * refactor ctx manager * fix isort * black format * isort fix * fix format * update args * update black * cleanups * Update perf_train_gpu_one.mdx * code refactor * code refactor to init * remove redundancy * isort * replace self.args with args Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
@@ -62,6 +62,7 @@ from transformers.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_tensorrt_fx,
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_gpus,
|
||||
require_torchdynamo,
|
||||
@@ -1796,6 +1797,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
@require_torchdynamo
|
||||
@require_torch_tensorrt_fx
|
||||
def test_torchdynamo_full_eval(self):
|
||||
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
|
||||
n_gpus = get_gpu_count()
|
||||
@@ -1824,6 +1826,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
metrics = trainer.evaluate()
|
||||
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
|
||||
|
||||
# 4. TorchDynamo fx2trt
|
||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
|
||||
metrics = trainer.evaluate()
|
||||
t1 = metrics["eval_loss"]
|
||||
t2 = original_eval_loss
|
||||
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
|
||||
|
||||
# 5. TorchDynamo fx2trt-fp16
|
||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
|
||||
metrics = trainer.evaluate()
|
||||
t1 = metrics["eval_loss"]
|
||||
t2 = original_eval_loss
|
||||
# fp16 has accuracy accuracy degradation
|
||||
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
@require_torchdynamo
|
||||
def test_torchdynamo_memory(self):
|
||||
@@ -1849,7 +1866,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
mod = MyModule()
|
||||
|
||||
# 1. Default - without TorchDynamo
|
||||
# 1. without TorchDynamo (eager baseline)
|
||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
||||
a.grad = None
|
||||
trainer = CustomTrainer(model=mod)
|
||||
@@ -1857,16 +1874,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
for _ in range(10):
|
||||
orig_loss = trainer.training_step(mod, {"x": a})
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
orig_loss = trainer.training_step(mod, {"x": a})
|
||||
orig_peak_mem = torch.cuda.max_memory_allocated()
|
||||
del trainer
|
||||
|
||||
# Reset the peak for another measurement
|
||||
# resets
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
orig_loss = trainer.training_step(mod, {"x": a})
|
||||
orig_peak_mem = torch.cuda.max_memory_allocated()
|
||||
del trainer
|
||||
|
||||
# 2. TorchDynamo nvfuser
|
||||
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
|
||||
a.grad = None
|
||||
@@ -1876,7 +1892,11 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
for _ in range(10):
|
||||
loss = trainer.training_step(mod, {"x": a})
|
||||
|
||||
# resets
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
loss = trainer.training_step(mod, {"x": a})
|
||||
peak_mem = torch.cuda.max_memory_allocated()
|
||||
del trainer
|
||||
|
||||
Reference in New Issue
Block a user