From 19e9079dc1cd8db16e1ceca14a713a077106951d Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Thu, 24 Apr 2025 03:29:42 +0800 Subject: [PATCH] enable 4 test_trainer cases on XPU (#37645) Signed-off-by: YAO Matrix --- tests/trainer/test_trainer.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0eb2ba6989..8e1a1c931e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1817,7 +1817,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm)) @require_liger_kernel - @require_torch_gpu + @require_torch_accelerator def test_use_liger_kernel_trainer(self): # Check that trainer still works with liger kernel applied config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) @@ -1921,7 +1921,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): _ = trainer.train() @require_schedulefree - @require_torch_gpu + @require_torch_accelerator def test_schedulefree_radam(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -2225,7 +2225,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(lower_bound_pm < galore_peak_memory) @require_galore_torch - @require_torch_gpu + @require_torch_accelerator def test_galore_lr_display_without_scheduler(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -2250,7 +2250,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate]) @require_galore_torch - @require_torch_gpu + @require_torch_accelerator def test_galore_lr_display_with_scheduler(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -2276,22 +2276,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # creating log history of trainer, results don't matter trainer.train() - logs = trainer.state.log_history[1:][:-1] + logs = trainer.state.log_history[1:-1] # reach given learning rate peak and end with 0 lr - self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate) - self.assertTrue(logs[-1]["learning_rate"] == 0) + self.assertTrue(logs[num_warmup_steps - 1]["learning_rate"] == learning_rate) + # self.assertTrue(logs[-1]["learning_rate"] == 0) + self.assertTrue(np.allclose(logs[-1]["learning_rate"], 0, atol=5e-6)) # increasing and decreasing pattern of lrs increasing_lrs = [ logs[i]["learning_rate"] < logs[i + 1]["learning_rate"] for i in range(len(logs)) - if i < num_warmup_steps - 2 + if i < num_warmup_steps - 1 ] decreasing_lrs = [ logs[i]["learning_rate"] > logs[i + 1]["learning_rate"] for i in range(len(logs) - 1) - if i >= num_warmup_steps - 2 + if i >= num_warmup_steps - 1 ] self.assertTrue(all(increasing_lrs))