From 6e7cc5cc51cce242ad9a1b13a2e205ab168730d2 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Tue, 15 Jun 2021 11:12:59 -0700
Subject: [PATCH] [testing] ensure concurrent pytest workers use a unique port for torch.dist (#12166)

* ensure concurrent pytest workers use a unique port for torch.distributed.launch

* reword
---
 src/transformers/testing_utils.py  | 22 ++++++++++++++++++++++
 tests/extended/test_trainer_ext.py |  3 +++
 tests/test_trainer_distributed.py  |  8 +++++++-
 3 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 9bfb972217..ca607c3301 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -1249,6 +1249,28 @@ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False
     return result
 
 
+def pytest_xdist_worker_id():
+    """
+    Returns an int value of worker's numerical id under ``pytest-xdist``'s concurrent workers ``pytest -n N`` regime,
+    or 0 if ``-n 1`` or ``pytest-xdist`` isn't being used.
+    """
+    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
+    worker = re.sub(r"^gw", "", worker, flags=re.M)
+    return int(worker)
+
+
+def get_torch_dist_unique_port():
+    """
+    Returns a port number that can be fed to ``torch.distributed.launch``'s ``--master_port`` argument.
+
+    Under ``pytest-xdist`` it adds a delta number based on a worker id so that concurrent tests don't try to use the
+    same port at once.
+    """
+    port = 29500
+    uniq_delta = pytest_xdist_worker_id()
+    return port + uniq_delta
+
+
 def nested_simplify(obj, decimals=3):
     """
     Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
diff --git a/tests/extended/test_trainer_ext.py b/tests/extended/test_trainer_ext.py
index 4cf16549c7..93ef0ddb55 100644
--- a/tests/extended/test_trainer_ext.py
+++ b/tests/extended/test_trainer_ext.py
@@ -25,6 +25,7 @@ from transformers.testing_utils import (
     TestCasePlus,
     execute_subprocess_async,
     get_gpu_count,
+    get_torch_dist_unique_port,
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torch_non_multi_gpu,
@@ -223,9 +224,11 @@ class TestTrainerExt(TestCasePlus):
 
         if distributed:
             n_gpu = get_gpu_count()
+            master_port = get_torch_dist_unique_port()
             distributed_args = f"""
                 -m torch.distributed.launch
                 --nproc_per_node={n_gpu}
+                --master_port={master_port}
                 {self.examples_dir_str}/pytorch/translation/run_translation.py
             """.split()
             cmd = [sys.executable] + distributed_args + args
diff --git a/tests/test_trainer_distributed.py b/tests/test_trainer_distributed.py
index 4f455c7dae..b40526c6de 100644
--- a/tests/test_trainer_distributed.py
+++ b/tests/test_trainer_distributed.py
@@ -16,7 +16,12 @@ import sys
 from typing import Dict
 
 from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
-from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
+from transformers.testing_utils import (
+    TestCasePlus,
+    execute_subprocess_async,
+    get_torch_dist_unique_port,
+    require_torch_multi_gpu,
+)
 from transformers.utils import logging
 
 
@@ -64,6 +69,7 @@ class TestTrainerDistributed(TestCasePlus):
         distributed_args = f"""
             -m torch.distributed.launch
             --nproc_per_node={torch.cuda.device_count()}
+            --master_port={get_torch_dist_unique_port()}
             {self.test_file_dir}/test_trainer_distributed.py
         """.split()
         output_dir = self.get_auto_remove_tmp_dir()