[testing] ensure concurrent pytest workers use a unique port for torch.dist (#12166)
* ensure concurrent pytest workers use a unique port for torch.distributed.launch * reword
This commit is contained in:
@@ -16,7 +16,12 @@ import sys
|
||||
from typing import Dict
|
||||
|
||||
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_torch_dist_unique_port,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
@@ -64,6 +69,7 @@ class TestTrainerDistributed(TestCasePlus):
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
--nproc_per_node={torch.cuda.device_count()}
|
||||
--master_port={get_torch_dist_unique_port()}
|
||||
{self.test_file_dir}/test_trainer_distributed.py
|
||||
""".split()
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
Reference in New Issue
Block a user