From f3598a95c7558ed051ec76bb1ddd3fbe751abcef Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Tue, 29 Jul 2025 01:51:00 -0700 Subject: [PATCH] extend more trainer test cases to XPU, all pass (#39652) extend more trainer test cases to XPU Signed-off-by: Yao, Matrix --- tests/trainer/test_trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index dd0e3af4ab..3c2cc0ce89 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2520,7 +2520,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertTrue(len(decreasing_lrs) > len(increasing_lrs)) @require_torch_optimi - @require_torch_gpu + @require_torch_accelerator def test_stable_adamw(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -2539,7 +2539,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): _ = trainer.train() @require_torch_optimi - @require_torch_gpu + @require_torch_accelerator def test_stable_adamw_extra_args(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -2561,7 +2561,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): _ = trainer.train() @require_torch_optimi - @require_torch_gpu + @require_torch_accelerator def test_stable_adamw_lr_display_without_scheduler(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -2586,7 +2586,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate]) @require_torch_optimi - @require_torch_gpu + @require_torch_accelerator def test_stable_adamw_lr_display_with_scheduler(self): config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4) tiny_llama = LlamaForCausalLM(config) @@ -2615,19 +2615,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): logs = trainer.state.log_history[1:][:-1] # reach given learning rate peak and end with 0 lr - self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate) - self.assertTrue(logs[-1]["learning_rate"] == 0) + self.assertTrue(logs[num_warmup_steps - 1]["learning_rate"] == learning_rate) + self.assertTrue(np.allclose(logs[-1]["learning_rate"], 0, atol=5e-6)) # increasing and decreasing pattern of lrs increasing_lrs = [ logs[i]["learning_rate"] < logs[i + 1]["learning_rate"] for i in range(len(logs)) - if i < num_warmup_steps - 2 + if i < num_warmup_steps - 1 ] decreasing_lrs = [ logs[i]["learning_rate"] > logs[i + 1]["learning_rate"] for i in range(len(logs) - 1) - if i >= num_warmup_steps - 2 + if i >= num_warmup_steps - 1 ] self.assertTrue(all(increasing_lrs))