[s2s trainer] fix DP mode (#8823)
* fix DP case on multi-gpu
* make executable
* test all 3 modes
* use the correct check for distributed
* dp doesn't need a special case
* restore original name
* cleanup
This commit is contained in:
2
examples/seq2seq/finetune_trainer.py
Normal file → Executable file
2
examples/seq2seq/finetune_trainer.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|||||||
@@ -122,7 +122,8 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
else:
|
else:
|
||||||
if self.args.sortish_sampler:
|
if self.args.sortish_sampler:
|
||||||
self.train_dataset.make_sortish_sampler(
|
self.train_dataset.make_sortish_sampler(
|
||||||
self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
|
self.args.per_device_train_batch_size,
|
||||||
|
distributed=(self.args.local_rank != -1),
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -4,7 +4,14 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
from transformers import BertTokenizer, EncoderDecoderModel
|
from transformers import BertTokenizer, EncoderDecoderModel
|
||||||
from transformers.file_utils import is_datasets_available
|
from transformers.file_utils import is_datasets_available
|
||||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
|
from transformers.testing_utils import (
|
||||||
|
TestCasePlus,
|
||||||
|
execute_subprocess_async,
|
||||||
|
get_gpu_count,
|
||||||
|
require_torch_multi_gpu,
|
||||||
|
require_torch_non_multi_gpu,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
from transformers.trainer_utils import set_seed
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
@@ -18,17 +25,32 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
|||||||
|
|
||||||
|
|
||||||
class TestFinetuneTrainer(TestCasePlus):
|
class TestFinetuneTrainer(TestCasePlus):
|
||||||
def test_finetune_trainer(self):
|
def finetune_trainer_quick(self, distributed=None):
|
||||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
|
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed)
|
||||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||||
first_step_stats = eval_metrics[0]
|
first_step_stats = eval_metrics[0]
|
||||||
assert "eval_bleu" in first_step_stats
|
assert "eval_bleu" in first_step_stats
|
||||||
|
|
||||||
|
@require_torch_non_multi_gpu
|
||||||
|
def test_finetune_trainer_no_dist(self):
|
||||||
|
self.finetune_trainer_quick()
|
||||||
|
|
||||||
|
# the following 2 tests verify that the trainer can handle distributed and non-distributed with n_gpu > 1
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_finetune_trainer_dp(self):
|
||||||
|
self.finetune_trainer_quick(distributed=False)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_finetune_trainer_ddp(self):
|
||||||
|
self.finetune_trainer_quick(distributed=True)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_finetune_trainer_slow(self):
|
def test_finetune_trainer_slow(self):
|
||||||
# There is a missing call to __init__process_group somewhere
|
# There is a missing call to __init__process_group somewhere
|
||||||
output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10)
|
output_dir = self.run_trainer(
|
||||||
|
eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10, distributed=False
|
||||||
|
)
|
||||||
|
|
||||||
# Check metrics
|
# Check metrics
|
||||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
@@ -158,7 +180,9 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
# start training
|
# start training
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
|
def run_trainer(
|
||||||
|
self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int, distributed: bool = False
|
||||||
|
):
|
||||||
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
||||||
output_dir = self.get_auto_remove_tmp_dir()
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
args = f"""
|
args = f"""
|
||||||
@@ -193,8 +217,8 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
""".split()
|
""".split()
|
||||||
# --eval_beams 2
|
# --eval_beams 2
|
||||||
|
|
||||||
n_gpu = get_gpu_count()
|
if distributed:
|
||||||
if n_gpu > 1:
|
n_gpu = get_gpu_count()
|
||||||
distributed_args = f"""
|
distributed_args = f"""
|
||||||
-m torch.distributed.launch
|
-m torch.distributed.launch
|
||||||
--nproc_per_node={n_gpu}
|
--nproc_per_node={n_gpu}
|
||||||
@@ -203,7 +227,6 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
cmd = [sys.executable] + distributed_args + args
|
cmd = [sys.executable] + distributed_args + args
|
||||||
execute_subprocess_async(cmd, env=self.get_env())
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
else:
|
else:
|
||||||
# 0 or 1 gpu
|
|
||||||
testargs = ["finetune_trainer.py"] + args
|
testargs = ["finetune_trainer.py"] + args
|
||||||
with patch.object(sys, "argv", testargs):
|
with patch.object(sys, "argv", testargs):
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user