From 7caa57e85e74a3cc7c16a115fb3fff78aac7a609 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Thu, 15 May 2025 20:17:44 +0800 Subject: [PATCH] enable trainer test cases on xpu (#38138) * enable trainer test cases on xpu Signed-off-by: Matrix Yao * fix style Signed-off-by: Matrix Yao --------- Signed-off-by: Matrix Yao --- src/transformers/testing_utils.py | 2 +- tests/trainer/test_trainer.py | 3 +-- tests/trainer/test_trainer_distributed_loss.py | 9 +++++---- tests/trainer/test_trainer_distributed_worker_seed.py | 10 ++++++---- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 5ab348377d..6a16581e81 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -2039,7 +2039,7 @@ class TestCasePlus(unittest.TestCase): """ env = os.environ.copy() - paths = [self.src_dir_str] + paths = [self.repo_root_dir_str, self.src_dir_str] if "/examples" in self.test_file_dir_str: paths.append(self.examples_dir_str) else: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index f687d2d88e..b5fb3d64e7 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -97,7 +97,6 @@ from transformers.testing_utils import ( require_torch_fp16, require_torch_gpu, require_torch_multi_accelerator, - require_torch_multi_gpu, require_torch_non_multi_accelerator, require_torch_non_multi_gpu, require_torch_tensorrt_fx, @@ -3766,7 +3765,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): train_output = trainer.train() self.assertEqual(train_output.global_step, int(self.n_epochs)) - @require_torch_multi_gpu + @require_torch_multi_accelerator def test_num_batches_in_training_with_gradient_accumulation(self): with tempfile.TemporaryDirectory() as tmp_dir: for num_train_epochs in [1, 2]: diff --git a/tests/trainer/test_trainer_distributed_loss.py b/tests/trainer/test_trainer_distributed_loss.py index 925b7b8ba5..9bae7c9265 100644 --- a/tests/trainer/test_trainer_distributed_loss.py +++ b/tests/trainer/test_trainer_distributed_loss.py @@ -1,7 +1,6 @@ import json import datasets -import torch from tests.trainer.test_trainer import StoreLossCallback from transformers import ( @@ -15,16 +14,18 @@ from transformers import ( ) from transformers.testing_utils import ( TestCasePlus, + backend_device_count, execute_subprocess_async, get_torch_dist_unique_port, - require_torch_multi_gpu, + require_torch_multi_accelerator, + torch_device, ) class TestTrainerDistributedLoss(TestCasePlus): - @require_torch_multi_gpu + @require_torch_multi_accelerator def test_trainer(self): - device_count = torch.cuda.device_count() + device_count = backend_device_count(torch_device) min_bs = 1 output_dir = self.get_auto_remove_tmp_dir() for gpu_num, enable, bs, name in ( diff --git a/tests/trainer/test_trainer_distributed_worker_seed.py b/tests/trainer/test_trainer_distributed_worker_seed.py index f4fececf10..3fa625af74 100644 --- a/tests/trainer/test_trainer_distributed_worker_seed.py +++ b/tests/trainer/test_trainer_distributed_worker_seed.py @@ -14,9 +14,11 @@ from transformers import ( ) from transformers.testing_utils import ( TestCasePlus, + backend_device_count, execute_subprocess_async, get_torch_dist_unique_port, - require_torch_multi_gpu, + require_torch_multi_accelerator, + torch_device, ) @@ -47,7 +49,7 @@ class DummyModel(nn.Module): self.fc = nn.Linear(3, 1) def forward(self, x): - local_tensor = torch.tensor(x, device="cuda") + local_tensor = torch.tensor(x, device=torch_device) gathered = gather_from_all_gpus(local_tensor, dist.get_world_size()) assert not all(torch.allclose(t, gathered[0]) for t in gathered[1:]) y = self.fc(x) @@ -55,9 +57,9 @@ class DummyModel(nn.Module): class TestTrainerDistributedWorkerSeed(TestCasePlus): - @require_torch_multi_gpu + @require_torch_multi_accelerator def test_trainer(self): - device_count = torch.cuda.device_count() + device_count = backend_device_count(torch_device) output_dir = self.get_auto_remove_tmp_dir() distributed_args = f"""--nproc_per_node={device_count} --master_port={get_torch_dist_unique_port()}