Fix pad across processes dim in trainer and not being able to set the timeout (#24775)
* dim, and rm copy * Don't rm copy for now * Oops * pad index * Should be a working test * Tickle down ddp timeout * Put fix back in now that testing locally is done * Better comment specifying timeout Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -49,6 +49,7 @@ from transformers.testing_utils import (
|
||||
USER,
|
||||
CaptureLogger,
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
get_tests_dir,
|
||||
is_staging_test,
|
||||
@@ -2098,6 +2099,51 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
|
||||
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
||||
|
||||
@slow
|
||||
@require_torch_multi_gpu
|
||||
def test_end_to_end_example(self):
|
||||
# Tests that `translation.py` will run without issues
|
||||
script_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "examples", "pytorch", "translation", "run_translation.py"
|
||||
)
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
command = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
script_path,
|
||||
"--model_name_or_path",
|
||||
"t5-small",
|
||||
"--per_device_train_batch_size",
|
||||
"1",
|
||||
"--output_dir",
|
||||
tmpdir,
|
||||
"--overwrite_output_dir",
|
||||
"--do_train",
|
||||
"--max_train_samples",
|
||||
"64",
|
||||
"--num_train_epochs",
|
||||
"1",
|
||||
"--dataset_name",
|
||||
"wmt16",
|
||||
"--dataset_config",
|
||||
"ro-en",
|
||||
"--source_lang",
|
||||
"en",
|
||||
"--target_lang",
|
||||
"ro",
|
||||
"--do_predict",
|
||||
"--max_predict_samples",
|
||||
"64",
|
||||
"--predict_with_generate",
|
||||
"--ddp_timeout",
|
||||
"60",
|
||||
]
|
||||
execute_subprocess_async(command)
|
||||
# successful return here == success - any errors would have caused an error or a timeout in the sub-call
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user