[testing] port test_trainer_distributed to distributed pytest + TestCasePlus enhancements (#8107)
* move the helper code into testing_utils * port test_trainer_distributed to work with pytest * improve docs * simplify notes * doc * doc * style * doc * further improvements * torch might not be available * real fix * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -5,12 +5,10 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_multigpu
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multigpu
|
||||
|
||||
from .utils import execute_async_std, load_json
|
||||
from .utils import load_json
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -157,23 +155,9 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
return f"--{k}"
|
||||
return f"--{k}={v}"
|
||||
|
||||
path = Path(__file__).resolve()
|
||||
cur_path = path.parents[0]
|
||||
examples_path = path.parents[1]
|
||||
src_path = f"{path.parents[2]}/src"
|
||||
env = os.environ.copy()
|
||||
env["PYTHONPATH"] = f"{examples_path}:{src_path}:{env.get('PYTHONPATH', '')}"
|
||||
|
||||
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
|
||||
cmd = [sys.executable, f"{cur_path}/distillation.py"] + cli_args
|
||||
|
||||
print("\nRunning: ", " ".join(cmd))
|
||||
|
||||
result = execute_async_std(cmd, env=env, stdin=None, timeout=180, quiet=False, echo=False)
|
||||
|
||||
assert result.stdout, "produced no output"
|
||||
if result.returncode > 0:
|
||||
pytest.fail(f"failed with returncode {result.returncode}")
|
||||
cmd = [sys.executable, f"{self.test_file_dir}/distillation.py"] + cli_args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
|
||||
Reference in New Issue
Block a user