Device agnostic trainer testing (#27131)

This commit is contained in:
Hz, Ji
2023-10-31 02:16:40 +08:00
committed by GitHub
parent 84724efd10
commit 5bbf671276
3 changed files with 87 additions and 46 deletions

View File

@@ -26,16 +26,17 @@ from transformers.testing_utils import (
CaptureStderr,
ExtendSysPath,
TestCasePlus,
backend_device_count,
execute_subprocess_async,
get_gpu_count,
get_torch_dist_unique_port,
require_apex,
require_bitsandbytes,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_torch_non_multi_gpu,
require_torch_multi_accelerator,
require_torch_non_multi_accelerator,
slow,
torch_device,
)
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
@@ -89,17 +90,17 @@ class TestTrainerExt(TestCasePlus):
assert isinstance(last_step_stats["eval_bleu"], float)
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
@require_torch_non_multi_gpu
@require_torch_non_multi_accelerator
def test_run_seq2seq_no_dist(self):
self.run_seq2seq_quick()
# verify that the trainer can handle a non-distributed run with n_gpu > 1
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_run_seq2seq_dp(self):
self.run_seq2seq_quick(distributed=False)
# verify that the trainer can handle a distributed run with n_gpu > 1
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_run_seq2seq_ddp(self):
self.run_seq2seq_quick(distributed=True)
@@ -120,7 +121,7 @@ class TestTrainerExt(TestCasePlus):
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
@parameterized.expand(["base", "low", "high", "mixed"])
@require_torch_multi_gpu
@require_torch_multi_accelerator
def test_trainer_log_level_replica(self, experiment_id):
# each sub-test is slow-ish, so the experiments are split into multiple sub-tests to avoid a CI timeout
experiments = {
@@ -331,7 +332,7 @@ class TestTrainerExt(TestCasePlus):
if distributed:
if n_gpus_to_use is None:
n_gpus_to_use = get_gpu_count()
n_gpus_to_use = backend_device_count(torch_device)
master_port = get_torch_dist_unique_port()
distributed_args = f"""
-m torch.distributed.run