Enable torchdynamo with torch_tensorrt(fx path) (#17765)

* enable fx2trt

* Update perf_train_gpu_one.mdx

* Update perf_train_gpu_one.mdx

* add lib check

* update

* format

* update

* fix import check

* fix isort

* improve doc

* refactor ctx manager

* fix isort

* black format

* isort fix

* fix format

* update args

* update black

* cleanups

* Update perf_train_gpu_one.mdx

* code refactor

* code refactor to init

* remove redundancy

* isort

* replace self.args with args

Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Wei
2022-07-13 09:43:28 -07:00
committed by GitHub
parent 37aeb5787a
commit 7ea6ccc2b3
7 changed files with 88 additions and 22 deletions

View File

@@ -62,6 +62,7 @@ from transformers.testing_utils import (
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
require_torch_tensorrt_fx,
require_torch_tf32,
require_torch_up_to_2_gpus,
require_torchdynamo,
@@ -1796,6 +1797,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
@require_torch_non_multi_gpu
@require_torchdynamo
@require_torch_tensorrt_fx
def test_torchdynamo_full_eval(self):
# torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
n_gpus = get_gpu_count()
@@ -1824,6 +1826,21 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
metrics = trainer.evaluate()
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
# 4. TorchDynamo fx2trt
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
# 5. TorchDynamo fx2trt-fp16
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
# fp16 has accuracy accuracy degradation
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):
@@ -1849,7 +1866,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
mod = MyModule()
# 1. Default - without TorchDynamo
# 1. without TorchDynamo (eager baseline)
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
trainer = CustomTrainer(model=mod)
@@ -1857,16 +1874,15 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
for _ in range(10):
orig_loss = trainer.training_step(mod, {"x": a})
torch.cuda.reset_peak_memory_stats()
orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
del trainer
# Reset the peak for another measurement
# resets
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
orig_loss = trainer.training_step(mod, {"x": a})
orig_peak_mem = torch.cuda.max_memory_allocated()
del trainer
# 2. TorchDynamo nvfuser
a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
a.grad = None
@@ -1876,7 +1892,11 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
for _ in range(10):
loss = trainer.training_step(mod, {"x": a})
# resets
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
loss = trainer.training_step(mod, {"x": a})
peak_mem = torch.cuda.max_memory_allocated()
del trainer