Repurpose torchdynamo training args towards torch._dynamo (#20498)

* Repurpose torchdynamo training args towards torch._dynamo

* Add doc
This commit is contained in:
Sylvain Gugger
2022-11-30 11:10:45 -05:00
committed by GitHub
parent 829374e4fc
commit 08b4621899
5 changed files with 53 additions and 73 deletions

View File

@@ -1839,20 +1839,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# 4. TorchDynamo fx2trt
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
torchdynamo.reset()
# 5. TorchDynamo fx2trt-fp16
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
metrics = trainer.evaluate()
t1 = metrics["eval_loss"]
t2 = original_eval_loss
# fp16 has accuracy accuracy degradation
self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
torchdynamo.reset()
@require_torch_non_multi_gpu
@require_torchdynamo
def test_torchdynamo_memory(self):