enable 4 test_trainer cases on XPU (#37645)

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix
2025-04-24 03:29:42 +08:00
committed by GitHub
parent 5cd6b64059
commit 19e9079dc1

View File

@@ -1817,7 +1817,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))
@require_liger_kernel
@require_torch_gpu
@require_torch_accelerator
def test_use_liger_kernel_trainer(self):
# Check that trainer still works with liger kernel applied
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
@@ -1921,7 +1921,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_schedulefree
@require_torch_gpu
@require_torch_accelerator
def test_schedulefree_radam(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2225,7 +2225,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_lr_display_without_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2250,7 +2250,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_lr_display_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2276,22 +2276,23 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# creating log history of trainer, results don't matter
trainer.train()
logs = trainer.state.log_history[1:][:-1]
logs = trainer.state.log_history[1:-1]
# reach given learning rate peak and end with 0 lr
self.assertTrue(logs[num_warmup_steps - 2]["learning_rate"] == learning_rate)
self.assertTrue(logs[-1]["learning_rate"] == 0)
self.assertTrue(logs[num_warmup_steps - 1]["learning_rate"] == learning_rate)
# self.assertTrue(logs[-1]["learning_rate"] == 0)
self.assertTrue(np.allclose(logs[-1]["learning_rate"], 0, atol=5e-6))
# increasing and decreasing pattern of lrs
increasing_lrs = [
logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
for i in range(len(logs))
if i < num_warmup_steps - 2
if i < num_warmup_steps - 1
]
decreasing_lrs = [
logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
for i in range(len(logs) - 1)
if i >= num_warmup_steps - 2
if i >= num_warmup_steps - 1
]
self.assertTrue(all(increasing_lrs))