Device agnostic trainer testing (#27131)
This commit is contained in:
@@ -26,16 +26,17 @@ from transformers.testing_utils import (
|
||||
CaptureStderr,
|
||||
ExtendSysPath,
|
||||
TestCasePlus,
|
||||
backend_device_count,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
get_torch_dist_unique_port,
|
||||
require_apex,
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_non_multi_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
@@ -89,17 +90,17 @@ class TestTrainerExt(TestCasePlus):
|
||||
assert isinstance(last_step_stats["eval_bleu"], float)
|
||||
assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"
|
||||
|
||||
@require_torch_non_multi_gpu
|
||||
@require_torch_non_multi_accelerator
|
||||
def test_run_seq2seq_no_dist(self):
|
||||
self.run_seq2seq_quick()
|
||||
|
||||
# verify that the trainer can handle non-distributed with n_gpu > 1
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_run_seq2seq_dp(self):
|
||||
self.run_seq2seq_quick(distributed=False)
|
||||
|
||||
# verify that the trainer can handle distributed with n_gpu > 1
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_run_seq2seq_ddp(self):
|
||||
self.run_seq2seq_quick(distributed=True)
|
||||
|
||||
@@ -120,7 +121,7 @@ class TestTrainerExt(TestCasePlus):
|
||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--fp16 --fp16_backend=apex")
|
||||
|
||||
@parameterized.expand(["base", "low", "high", "mixed"])
|
||||
@require_torch_multi_gpu
|
||||
@require_torch_multi_accelerator
|
||||
def test_trainer_log_level_replica(self, experiment_id):
|
||||
# as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
|
||||
experiments = {
|
||||
@@ -331,7 +332,7 @@ class TestTrainerExt(TestCasePlus):
|
||||
|
||||
if distributed:
|
||||
if n_gpus_to_use is None:
|
||||
n_gpus_to_use = get_gpu_count()
|
||||
n_gpus_to_use = backend_device_count(torch_device)
|
||||
master_port = get_torch_dist_unique_port()
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.run
|
||||
|
||||
Reference in New Issue
Block a user