[testing] ensure concurrent pytest workers use a unique port for torch.dist (#12166)
* ensure concurrent pytest workers use a unique port for torch.distributed.launch
* reword
This commit is contained in:
@@ -25,6 +25,7 @@ from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
get_torch_dist_unique_port,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
@@ -223,9 +224,11 @@ class TestTrainerExt(TestCasePlus):
|
||||
|
||||
if distributed:
|
||||
n_gpu = get_gpu_count()
|
||||
master_port = get_torch_dist_unique_port()
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
--nproc_per_node={n_gpu}
|
||||
--master_port={master_port}
|
||||
{self.examples_dir_str}/pytorch/translation/run_translation.py
|
||||
""".split()
|
||||
cmd = [sys.executable] + distributed_args + args
|
||||
|
||||
@@ -16,7 +16,12 @@ import sys
|
||||
from typing import Dict
|
||||
|
||||
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_torch_dist_unique_port,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
@@ -64,6 +69,7 @@ class TestTrainerDistributed(TestCasePlus):
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
--nproc_per_node={torch.cuda.device_count()}
|
||||
--master_port={get_torch_dist_unique_port()}
|
||||
{self.test_file_dir}/test_trainer_distributed.py
|
||||
""".split()
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
Reference in New Issue
Block a user