add tests for the new sharded ddp fairscale integration (#9177)
This commit is contained in:
@@ -14,10 +14,12 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from transformers import BertTokenizer, EncoderDecoderModel
|
from transformers import BertTokenizer, EncoderDecoderModel
|
||||||
from transformers.file_utils import is_datasets_available
|
from transformers.file_utils import is_datasets_available
|
||||||
|
from transformers.integrations import is_fairscale_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
execute_subprocess_async,
|
execute_subprocess_async,
|
||||||
@@ -38,9 +40,20 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
|||||||
MBART_TINY = "sshleifer/tiny-mbart"
|
MBART_TINY = "sshleifer/tiny-mbart"
|
||||||
|
|
||||||
|
|
||||||
|
# a candidate for testing_utils
|
||||||
|
def require_fairscale(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires fairscale
|
||||||
|
"""
|
||||||
|
if not is_fairscale_available():
|
||||||
|
return unittest.skip("test requires fairscale")(test_case)
|
||||||
|
else:
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
class TestFinetuneTrainer(TestCasePlus):
|
class TestFinetuneTrainer(TestCasePlus):
|
||||||
def finetune_trainer_quick(self, distributed=None):
|
def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
|
||||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed)
|
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
|
||||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||||
first_step_stats = eval_metrics[0]
|
first_step_stats = eval_metrics[0]
|
||||||
@@ -59,6 +72,16 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
def test_finetune_trainer_ddp(self):
|
def test_finetune_trainer_ddp(self):
|
||||||
self.finetune_trainer_quick(distributed=True)
|
self.finetune_trainer_quick(distributed=True)
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
@require_fairscale
|
||||||
|
def test_finetune_trainer_ddp_sharded_ddp(self):
|
||||||
|
self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp")
|
||||||
|
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
@require_fairscale
|
||||||
|
def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
|
||||||
|
self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_finetune_trainer_slow(self):
|
def test_finetune_trainer_slow(self):
|
||||||
# There is a missing call to __init__process_group somewhere
|
# There is a missing call to __init__process_group somewhere
|
||||||
@@ -195,7 +218,13 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
def run_trainer(
|
def run_trainer(
|
||||||
self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int, distributed: bool = False
|
self,
|
||||||
|
eval_steps: int,
|
||||||
|
max_len: str,
|
||||||
|
model_name: str,
|
||||||
|
num_train_epochs: int,
|
||||||
|
distributed: bool = False,
|
||||||
|
extra_args_str: str = None,
|
||||||
):
|
):
|
||||||
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
||||||
output_dir = self.get_auto_remove_tmp_dir()
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
@@ -231,6 +260,9 @@ class TestFinetuneTrainer(TestCasePlus):
|
|||||||
""".split()
|
""".split()
|
||||||
# --eval_beams 2
|
# --eval_beams 2
|
||||||
|
|
||||||
|
if extra_args_str is not None:
|
||||||
|
args.extend(extra_args_str.split())
|
||||||
|
|
||||||
if distributed:
|
if distributed:
|
||||||
n_gpu = get_gpu_count()
|
n_gpu = get_gpu_count()
|
||||||
distributed_args = f"""
|
distributed_args = f"""
|
||||||
|
|||||||
Reference in New Issue
Block a user