From 3601aa8fc9c85cc2c41acae357532ee3b267fb9a Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 18 May 2022 16:00:47 -0700 Subject: [PATCH] [tests] fix copy-n-paste error (#17312) * [tests] fix copy-n-paste error * fix --- tests/trainer/test_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0650916c11..cb9bde6329 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1551,7 +1551,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): a = torch.ones(1000, bs) + 0.001 b = torch.ones(1000, bs) - 0.001 - # 1. with mem metrics enabled + # 1. with fp16_full_eval disabled trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False) metrics = trainer.evaluate() del trainer @@ -1572,7 +1572,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # perfect world: fp32_eval == close to zero self.assertLess(fp32_eval, 5_000) - # 2. with mem metrics disabled + # 2. with fp16_full_eval enabled trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, fp16_full_eval=True, skip_memory_metrics=False) metrics = trainer.evaluate() fp16_init = metrics["init_mem_gpu_alloc_delta"] @@ -1611,7 +1611,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): a = torch.ones(1000, bs) + 0.001 b = torch.ones(1000, bs) - 0.001 - # 1. with mem metrics enabled + # 1. with bf16_full_eval disabled trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False) metrics = trainer.evaluate() del trainer @@ -1632,7 +1632,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # perfect world: fp32_eval == close to zero self.assertLess(fp32_eval, 5_000) - # 2. with mem metrics disabled + # 2. with bf16_full_eval enabled trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, bf16_full_eval=True, skip_memory_metrics=False) metrics = trainer.evaluate() bf16_init = metrics["init_mem_gpu_alloc_delta"]