[testing] ensure concurrent pytest workers use a unique port for torch.distributed.launch (#12166)

* ensure concurrent pytest workers use a unique port for torch.distributed.launch

* reword
This commit is contained in:
Stas Bekman
2021-06-15 11:12:59 -07:00
committed by GitHub
parent b9d66f4c4b
commit 6e7cc5cc51
3 changed files with 32 additions and 1 deletion

View File

@@ -16,7 +16,12 @@ import sys
from typing import Dict
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_gpu,
)
from transformers.utils import logging
@@ -64,6 +69,7 @@ class TestTrainerDistributed(TestCasePlus):
distributed_args = f"""
-m torch.distributed.launch
--nproc_per_node={torch.cuda.device_count()}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_trainer_distributed.py
""".split()
output_dir = self.get_auto_remove_tmp_dir()