[testing] port test_trainer_distributed to distributed pytest + TestCasePlus enhancements (#8107)
* move the helper code into testing_utils
* port test_trainer_distributed to work with pytest
* improve docs
* simplify notes
* doc
* doc
* style
* doc
* further improvements
* torch might not be available
* real fix
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@@ -470,7 +470,7 @@ This is still under development but you can study 2 different tests that perform
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/test_finetune_trainer.py>`__ - a normal
|
||||
(non-PL) test
|
||||
|
||||
To jump right into the execution point, search for the ``execute_async_std`` function in those tests.
|
||||
To jump right into the execution point, search for the ``execute_subprocess_async`` function in those tests.
|
||||
|
||||
You will need at least 2 GPUs to see these tests in action:
|
||||
|
||||
@@ -646,6 +646,55 @@ as in the previous example.
|
||||
|
||||
|
||||
|
||||
Files and directories
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Tests often need to know where things are relative to the current test file, which is not trivial because the test
could be invoked from more than one directory or could reside in sub-directories of different depths. The helper class
:obj:`transformers.testing_utils.TestCasePlus` solves this problem by sorting out all the basic paths and providing
easy accessors to them:
|
||||
|
||||
* ``pathlib`` objects (all fully resolved):
|
||||
|
||||
- ``test_file_path`` - the current test file path, i.e. ``__file__``
|
||||
- ``test_file_dir`` - the directory containing the current test file
|
||||
- ``tests_dir`` - the directory of the ``tests`` test suite
|
||||
- ``examples_dir`` - the directory of the ``examples`` test suite
|
||||
- ``repo_root_dir`` - the directory of the repository
|
||||
- ``src_dir`` - the directory of ``src`` (i.e. where the ``transformers`` sub-dir resides)
|
||||
|
||||
* stringified paths---same as above but these return paths as strings, rather than ``pathlib`` objects:
|
||||
|
||||
- ``test_file_path_str``
|
||||
- ``test_file_dir_str``
|
||||
- ``tests_dir_str``
|
||||
- ``examples_dir_str``
|
||||
- ``repo_root_dir_str``
|
||||
- ``src_dir_str``
|
||||
|
||||
To use these accessors, all you need is to make sure that the test resides in a subclass of
:obj:`transformers.testing_utils.TestCasePlus`. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
class PathExampleTest(TestCasePlus):
|
||||
def test_something_involving_local_locations(self):
|
||||
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
||||
|
||||
If you don't need to manipulate paths via ``pathlib`` or you just need a path as a string, you can always invoke
``str()`` on the ``pathlib`` object or use the accessors ending with ``_str``. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
class PathExampleTest(TestCasePlus):
|
||||
def test_something_involving_stringified_locations(self):
|
||||
examples_dir = self.examples_dir_str
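How accessors of this kind can locate the repository root regardless of the depth of the test file can be sketched as follows. This is only an illustration, not the library's code: the helper name ``find_repo_root``, its ``max_up`` parameter and the temporary directory layout are assumptions made for the example. The only convention taken from the text above is that the repository root contains both ``src`` and ``tests``.

```python
import tempfile
from pathlib import Path


def find_repo_root(test_file, max_up=4):
    """Walk up from ``test_file`` until a dir containing both ``src`` and ``tests`` is found.

    Hypothetical helper for illustration only; returns ``None`` if nothing matches.
    """
    path = Path(test_file).resolve()
    # slicing is safe even if the path has fewer than ``max_up`` parents
    for candidate in list(path.parents)[:max_up]:
        if (candidate / "src").is_dir() and (candidate / "tests").is_dir():
            return candidate
    return None


# Build a throwaway repo-like layout so that the sketch runs anywhere.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()
    (root / "src").mkdir()
    (root / "tests").mkdir()
    (root / "examples" / "seq2seq").mkdir(parents=True)
    test_file = root / "examples" / "seq2seq" / "test_something.py"
    test_file.touch()

    found = find_repo_root(test_file)
    found_ok = found == root
    examples_dir = found / "examples"
    print(found_ok, examples_dir.name)
```

Returning ``None`` instead of raising keeps the lookup separate from the error handling, so a caller can decide whether a missing root is fatal.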
|
||||
|
||||
|
||||
|
||||
|
||||
Temporary files and directories
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -1008,6 +1057,24 @@ If you want to test the impact of environment variables for a specific test you
|
||||
def test_env_override(self):
|
||||
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
|
||||
|
||||
At times an external program needs to be called, which requires setting ``PYTHONPATH`` in ``os.environ`` to include
multiple local paths. The helper class :obj:`transformers.testing_utils.TestCasePlus` helps with that:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
class EnvExampleTest(TestCasePlus):
|
||||
def test_external_prog(self):
|
||||
env = self.get_env()
|
||||
# now call the external program, passing ``env`` to it
|
||||
|
||||
Depending on whether the test file is under the ``tests`` test suite or ``examples``, it correctly sets up
``env["PYTHONPATH"]`` to include the ``src`` directory first, to ensure that testing is done against the current repo,
then one of these two directories, and finally whatever ``env["PYTHONPATH"]`` was already set to before the test was
called, if anything.
|
||||
|
||||
This helper method creates a copy of the ``os.environ`` object, so the original remains intact.
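The ``PYTHONPATH`` assembly described above can be sketched in isolation. The function name ``build_env`` and its parameters are hypothetical. The sketch joins entries with ``os.pathsep`` and skips an empty preset value; both are design choices of this example rather than a statement about what ``get_env`` does.

```python
import os


def build_env(src_dir, suite_dir, base_env=None):
    """Return a copy of ``base_env`` with ``PYTHONPATH`` set to src, suite dir, then any preset value.

    Hypothetical helper for illustration only.
    """
    # copy, so that the caller's mapping (or ``os.environ``) stays intact
    env = dict(os.environ if base_env is None else base_env)
    paths = [str(src_dir), str(suite_dir)]
    preset = env.get("PYTHONPATH", "")
    if preset:
        # only append a preset value if there is one, to avoid a trailing separator
        paths.append(preset)
    env["PYTHONPATH"] = os.pathsep.join(paths)
    return env


original = {"PYTHONPATH": "/already/set", "HOME": "/home/user"}
env = build_env("/repo/src", "/repo/examples", base_env=original)
print(env["PYTHONPATH"])
print(original["PYTHONPATH"])  # the input mapping is unchanged
```

The resulting mapping can then be passed as the ``env`` argument of whatever call launches the external program, for instance ``subprocess.run(cmd, env=env)``.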
|
||||
|
||||
|
||||
Getting reproducible results
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import BertTokenizer, EncoderDecoderModel, is_torch_available
|
||||
from transformers.file_utils import is_datasets_available
|
||||
from transformers.testing_utils import TestCasePlus, slow
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, slow
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .finetune_trainer import Seq2SeqTrainingArguments, main
|
||||
from .seq2seq_trainer import Seq2SeqTrainer
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
from .utils import execute_async_std
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -166,11 +162,9 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
trainer.train()
|
||||
|
||||
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
||||
|
||||
# XXX: remove hardcoded path
|
||||
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
|
||||
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
argv = f"""
|
||||
args = f"""
|
||||
--model_name_or_path {model_name}
|
||||
--data_dir {data_dir}
|
||||
--output_dir {output_dir}
|
||||
@@ -204,31 +198,16 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
|
||||
n_gpu = torch.cuda.device_count()
|
||||
if n_gpu > 1:
|
||||
|
||||
path = Path(__file__).resolve()
|
||||
cur_path = path.parents[0]
|
||||
|
||||
path = Path(__file__).resolve()
|
||||
examples_path = path.parents[1]
|
||||
src_path = f"{path.parents[2]}/src"
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
|
||||
|
||||
distributed_args = (
|
||||
f"-m torch.distributed.launch --nproc_per_node={n_gpu} {cur_path}/finetune_trainer.py".split()
|
||||
)
|
||||
cmd = [sys.executable] + distributed_args + argv
|
||||
|
||||
print("\nRunning: ", " ".join(cmd))
|
||||
|
||||
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
|
||||
|
||||
assert result.stdout, "produced no output"
|
||||
if result.returncode > 0:
|
||||
pytest.fail(f"failed with returncode {result.returncode}")
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
--nproc_per_node={n_gpu}
|
||||
{self.test_file_dir}/finetune_trainer.py
|
||||
""".split()
|
||||
cmd = [sys.executable] + distributed_args + args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
else:
|
||||
# 0 or 1 gpu
|
||||
testargs = ["finetune_trainer.py"] + argv
|
||||
testargs = ["finetune_trainer.py"] + args
|
||||
with patch.object(sys, "argv", testargs):
|
||||
main()
|
||||
|
||||
|
||||
@@ -5,12 +5,10 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_multigpu
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multigpu
|
||||
|
||||
from .utils import execute_async_std, load_json
|
||||
from .utils import load_json
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -157,23 +155,9 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
return f"--{k}"
|
||||
return f"--{k}={v}"
|
||||
|
||||
path = Path(__file__).resolve()
|
||||
cur_path = path.parents[0]
|
||||
examples_path = path.parents[1]
|
||||
src_path = f"{path.parents[2]}/src"
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
|
||||
|
||||
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
|
||||
cmd = [sys.executable, f"{cur_path}/distillation.py"] + cli_args
|
||||
|
||||
print("\nRunning: ", " ".join(cmd))
|
||||
|
||||
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
|
||||
|
||||
assert result.stdout, "produced no output"
|
||||
if result.returncode > 0:
|
||||
pytest.fail(f"failed with returncode {result.returncode}")
|
||||
cmd = [sys.executable, f"{self.test_file_dir}/distillation.py"] + cli_args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
|
||||
@@ -5,7 +5,6 @@ import math
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import sys
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
||||
@@ -641,74 +640,6 @@ def check_output_dir(args, expected_items=0):
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({args.output_dir}) already exists and "
|
||||
"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
|
||||
f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
|
||||
# the following code deals with async io between processes
|
||||
|
||||
# adapted from https://stackoverflow.com/a/59041913/9201239
|
||||
import asyncio # noqa
|
||||
|
||||
|
||||
class _RunOutput:
|
||||
def __init__(self, returncode, stdout, stderr):
|
||||
self.returncode = returncode
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
|
||||
|
||||
async def _read_stream(stream, callback):
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if line:
|
||||
callback(line)
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
|
||||
if echo:
|
||||
print(cmd)
|
||||
|
||||
p = await asyncio.create_subprocess_exec(
|
||||
cmd[0],
|
||||
*cmd[1:],
|
||||
stdin=stdin,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
out = []
|
||||
err = []
|
||||
|
||||
def tee(line, sink, pipe, label=""):
|
||||
line = line.decode("utf-8").rstrip()
|
||||
sink.append(line)
|
||||
if not quiet:
|
||||
print(label, line, file=pipe)
|
||||
|
||||
await asyncio.wait(
|
||||
[
|
||||
_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout)),
|
||||
_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
|
||||
],
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# XXX: warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
|
||||
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
|
||||
#
|
||||
# If it starts hanging, will need to switch s/wait/communicate/ - so perhaps for debug we will enable
|
||||
# `wait` as it's easier to see in real time, but for normal runs use `communicate`
|
||||
return _RunOutput(await p.wait(), out, err)
|
||||
|
||||
|
||||
def execute_async_std(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = loop.run_until_complete(
|
||||
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -476,7 +476,31 @@ class TestCasePlus(unittest.TestCase):
|
||||
"""
|
||||
This class extends `unittest.TestCase` with additional features.
|
||||
|
||||
Feature 1: Flexible auto-removable temp dirs which are guaranteed to get removed at the end of test.
|
||||
Feature 1: A set of fully resolved important file and dir path accessors.
|
||||
|
||||
Tests often need to know where things are relative to the current test file, which is not trivial because the
test could be invoked from more than one directory or could reside in sub-directories of different depths. This
class solves this problem by sorting out all the basic paths and providing easy accessors to them:
|
||||
|
||||
* ``pathlib`` objects (all fully resolved):
|
||||
|
||||
- ``test_file_path`` - the current test file path (=``__file__``)
|
||||
- ``test_file_dir`` - the directory containing the current test file
|
||||
- ``tests_dir`` - the directory of the ``tests`` test suite
|
||||
- ``examples_dir`` - the directory of the ``examples`` test suite
|
||||
- ``repo_root_dir`` - the directory of the repository
|
||||
- ``src_dir`` - the directory of ``src`` (i.e. where the ``transformers`` sub-dir resides)
|
||||
|
||||
* stringified paths---same as above but these return paths as strings, rather than ``pathlib`` objects:
|
||||
|
||||
- ``test_file_path_str``
|
||||
- ``test_file_dir_str``
|
||||
- ``tests_dir_str``
|
||||
- ``examples_dir_str``
|
||||
- ``repo_root_dir_str``
|
||||
- ``src_dir_str``
|
||||
|
||||
Feature 2: Flexible auto-removable temp dirs which are guaranteed to get removed at the end of test.
|
||||
|
||||
In all the following scenarios the temp dir will be auto-removed at the end of test, unless `after=False`.
|
||||
|
||||
@@ -499,7 +523,6 @@ class TestCasePlus(unittest.TestCase):
|
||||
temp results
|
||||
|
||||
::
|
||||
|
||||
def test_whatever(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", after=False)
|
||||
|
||||
@@ -517,11 +540,104 @@ class TestCasePlus(unittest.TestCase):
|
||||
|
||||
Note 2: Each test can register multiple temp dirs and they all will get auto-removed, unless requested otherwise.
|
||||
|
||||
Feature 3: Get a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` specific to the current test suite.
|
||||
This is useful for invoking external programs from the test suite - e.g. distributed training.
|
||||
|
||||
|
||||
::
|
||||
def test_whatever(self):
|
||||
env = self.get_env()
|
||||
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.teardown_tmp_dirs = []
|
||||
|
||||
# figure out the resolved paths for repo_root, tests, examples, etc.
|
||||
self._test_file_path = inspect.getfile(self.__class__)
|
||||
path = Path(self._test_file_path).resolve()
|
||||
self._test_file_dir = path.parents[0]
|
||||
        # walk up to 3 levels looking for the dir that holds both ``src`` and ``tests``
        for up in [1, 2, 3]:
            tmp_dir = path.parents[up]
            if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir():
                break
        else:
            # the loop finished without ``break``, i.e. no candidate matched
            tmp_dir = None
        if tmp_dir:
            self._repo_root_dir = tmp_dir
        else:
            raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
|
||||
self._tests_dir = self._repo_root_dir / "tests"
|
||||
self._examples_dir = self._repo_root_dir / "examples"
|
||||
self._src_dir = self._repo_root_dir / "src"
|
||||
|
||||
@property
|
||||
def test_file_path(self):
|
||||
return self._test_file_path
|
||||
|
||||
@property
|
||||
def test_file_path_str(self):
|
||||
return str(self._test_file_path)
|
||||
|
||||
@property
|
||||
def test_file_dir(self):
|
||||
return self._test_file_dir
|
||||
|
||||
@property
|
||||
def test_file_dir_str(self):
|
||||
return str(self._test_file_dir)
|
||||
|
||||
@property
|
||||
def tests_dir(self):
|
||||
return self._tests_dir
|
||||
|
||||
@property
|
||||
def tests_dir_str(self):
|
||||
return str(self._tests_dir)
|
||||
|
||||
@property
|
||||
def examples_dir(self):
|
||||
return self._examples_dir
|
||||
|
||||
@property
|
||||
def examples_dir_str(self):
|
||||
return str(self._examples_dir)
|
||||
|
||||
@property
|
||||
def repo_root_dir(self):
|
||||
return self._repo_root_dir
|
||||
|
||||
@property
|
||||
def repo_root_dir_str(self):
|
||||
return str(self._repo_root_dir)
|
||||
|
||||
@property
|
||||
def src_dir(self):
|
||||
return self._src_dir
|
||||
|
||||
@property
|
||||
def src_dir_str(self):
|
||||
return str(self._src_dir)
|
||||
|
||||
def get_env(self):
|
||||
"""
|
||||
Return a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` correctly, depending on the test suite
|
||||
it's invoked from. This is useful for invoking external programs from the test suite - e.g. distributed
|
||||
training.
|
||||
|
||||
It always inserts ``./src`` first, then ``./tests`` or ``./examples`` depending on the test suite type and
finally the preset ``PYTHONPATH`` if any (all fully resolved paths).
|
||||
|
||||
"""
|
||||
env = os.environ.copy()
|
||||
paths = [self.src_dir_str]
|
||||
if "/examples" in self.test_file_dir_str:
|
||||
paths.append(self.examples_dir_str)
|
||||
else:
|
||||
paths.append(self.tests_dir_str)
|
||||
paths.append(env.get("PYTHONPATH", ""))
|
||||
|
||||
env["PYTHONPATH"] = ":".join(paths)
|
||||
return env
|
||||
|
||||
def get_auto_remove_tmp_dir(self, tmp_dir=None, after=True, before=False):
|
||||
"""
|
||||
Args:
|
||||
@@ -676,3 +792,84 @@ def pytest_terminal_summary_main(tr, id):
|
||||
tr._tw = orig_writer
|
||||
tr.reportchars = orig_reportchars
|
||||
config.option.tbstyle = orig_tbstyle
|
||||
|
||||
|
||||
# the following code deals with async io between processes
|
||||
|
||||
# adapted from https://stackoverflow.com/a/59041913/9201239
|
||||
import asyncio # noqa
|
||||
|
||||
|
||||
class _RunOutput:
|
||||
def __init__(self, returncode, stdout, stderr):
|
||||
self.returncode = returncode
|
||||
self.stdout = stdout
|
||||
self.stderr = stderr
|
||||
|
||||
|
||||
async def _read_stream(stream, callback):
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if line:
|
||||
callback(line)
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
|
||||
if echo:
|
||||
print("\nRunning: ", " ".join(cmd))
|
||||
|
||||
p = await asyncio.create_subprocess_exec(
|
||||
cmd[0],
|
||||
*cmd[1:],
|
||||
stdin=stdin,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
|
||||
# note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
|
||||
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
|
||||
#
|
||||
# If it starts hanging, will need to switch to the following code. The problem is that no data
|
||||
# will be seen until it's done and if it hangs for example there will be no debug info.
|
||||
# out, err = await p.communicate()
|
||||
# return _RunOutput(p.returncode, out, err)
|
||||
|
||||
out = []
|
||||
err = []
|
||||
|
||||
def tee(line, sink, pipe, label=""):
|
||||
line = line.decode("utf-8").rstrip()
|
||||
sink.append(line)
|
||||
if not quiet:
|
||||
print(label, line, file=pipe)
|
||||
|
||||
# XXX: the timeout doesn't seem to make any difference here
|
||||
await asyncio.wait(
|
||||
[
|
||||
_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout)),
|
||||
_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
|
||||
],
|
||||
timeout=timeout,
|
||||
)
|
||||
return _RunOutput(await p.wait(), out, err)
|
||||
|
||||
|
||||
def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
result = loop.run_until_complete(
|
||||
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
|
||||
)
|
||||
|
||||
cmd_str = " ".join(cmd)
|
||||
if result.returncode > 0:
|
||||
raise RuntimeError(
|
||||
f"'{cmd_str}' failed with returncode {result.returncode} - see the `stderr:` messages from above for details."
|
||||
)
|
||||
if not result.stdout:
|
||||
raise RuntimeError(f"'{cmd_str}' produced no output.")
|
||||
|
||||
return result
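The streaming helper above passes bare coroutines to ``asyncio.wait``, which Python 3.11 and later no longer accept. A minimal sketch of the same tee pattern built on ``asyncio.gather`` and ``asyncio.run`` follows. The wrapper name ``run_streaming`` and the tuple it returns are assumptions made for this example, and the ``timeout`` handling is left out for brevity.

```python
import asyncio
import sys


async def _read_stream(stream, callback):
    # call ``callback`` for every line until the stream is exhausted
    while True:
        line = await stream.readline()
        if not line:
            break
        callback(line)


async def _stream_subprocess(cmd, env=None, quiet=False):
    p = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )
    out, err = [], []

    def tee(line, sink, pipe, label=""):
        # keep a copy of each line and optionally echo it in real time
        line = line.decode("utf-8").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    # gather() wraps the coroutines in tasks, so both pipes are drained concurrently
    await asyncio.gather(
        _read_stream(p.stdout, lambda line: tee(line, out, sys.stdout)),
        _read_stream(p.stderr, lambda line: tee(line, err, sys.stderr, label="stderr:")),
    )
    return await p.wait(), out, err


def run_streaming(cmd, env=None, quiet=False):
    """Hypothetical wrapper: run ``cmd`` and return ``(returncode, stdout_lines, stderr_lines)``."""
    return asyncio.run(_stream_subprocess(cmd, env=env, quiet=quiet))


child_code = "import sys; print('to stdout'); print('to stderr', file=sys.stderr)"
returncode, out, err = run_streaming([sys.executable, "-c", child_code], quiet=True)
```

Because both pipes are read to the end before ``p.wait()`` is awaited, the child cannot block on a full pipe buffer, which is the deadlock scenario the comment in the helper warns about.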
|
||||
|
||||
@@ -1,22 +1,8 @@
|
||||
# This test is meant to be run in torch.distributed,
|
||||
# on a machine with multiple GPUs, in the following way:
|
||||
#
|
||||
# python -m torch.distributed.launch --nproc_per_node 2 ./tests/test_trainer_distributed.py
|
||||
#
|
||||
# Replace 2 with the number of GPUs you have.
|
||||
#
|
||||
# You can also run it as a standalone file to test identical behavior in nn.DataParallel:
|
||||
# python ./tests/test_trainer_distributed.py
|
||||
# and in single-GPU mode:
|
||||
# CUDA_VISIBLE_DEVICES=0 python ./tests/test_trainer_distributed.py
|
||||
# and in CPU mode:
|
||||
# CUDA_VISIBLE_DEVICES=-1 python ./tests/test_trainer_distributed.py
|
||||
#
|
||||
|
||||
import sys
|
||||
from typing import Dict
|
||||
|
||||
from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multigpu
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
@@ -57,9 +43,28 @@ if is_torch_available():
|
||||
return input_ids
|
||||
|
||||
|
||||
class TestTrainerDistributed(TestCasePlus):
|
||||
@require_torch_multigpu
|
||||
def test_trainer(self):
|
||||
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
--nproc_per_node={torch.cuda.device_count()}
|
||||
{self.test_file_dir}/test_trainer_distributed.py
|
||||
""".split()
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args = f"--output_dir {output_dir}".split()
|
||||
cmd = [sys.executable] + distributed_args + args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
# successful return here == success - any errors would have caused an error in the sub-call
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
|
||||
#
|
||||
# PYTHONPATH="src" python -m torch.distributed.launch --nproc_per_node 2 ./tests/test_trainer_distributed.py --output_dir output_dir
|
||||
|
||||
parser = HfArgumentParser((TrainingArguments,))
|
||||
sys.argv += ["--output_dir", "./examples"]
|
||||
training_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
logger.warning(
|
||||
@@ -70,9 +75,8 @@ if __name__ == "__main__":
|
||||
training_args.local_rank != -1,
|
||||
)
|
||||
|
||||
# Essentially, what we want to verify in the distributed case is
|
||||
# that we get all samples back, in the right order.
|
||||
# (this is crucial for prediction for instance)
|
||||
# Essentially, what we want to verify in the distributed case is that we get all samples back,
|
||||
# in the right order. (this is crucial for prediction for instance)
|
||||
for dataset_length in [101, 40, 7]:
|
||||
dataset = DummyDataset(dataset_length)
|
||||
|
||||
@@ -115,5 +119,3 @@ if __name__ == "__main__":
|
||||
exit(1)
|
||||
|
||||
trainer.args.eval_accumulation_steps = None
|
||||
|
||||
logger.info("🔥 All distributed tests successful")
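The launch command in ``test_trainer`` above is assembled by calling ``split()`` on an f-string, which would break an argument apart if a path contained whitespace. A sketch that builds the list one argument per element avoids that. The helper name ``build_launch_cmd`` and the example paths are hypothetical; the launcher module and flags are the ones used in the test.

```python
import sys
from pathlib import Path


def build_launch_cmd(script, n_procs, script_args):
    """Hypothetical helper: assemble the launcher command as a list, one argument per element."""
    launcher_args = ["-m", "torch.distributed.launch", f"--nproc_per_node={n_procs}", str(script)]
    # arguments for the script itself go after the script path
    return [sys.executable] + launcher_args + list(script_args)


# example path with a space in it, which a plain ``split()`` would cut in two
script = Path("/path with spaces/tests/test_trainer_distributed.py")
cmd = build_launch_cmd(script, n_procs=2, script_args=["--output_dir", "/tmp/out"])
print(cmd)
```

The sketch only builds the command and does not import ``torch`` or start any process.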
|
||||
|
||||