From 63841c559b789abb72565bba6f69e0b3260e54f5 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Thu, 17 Dec 2020 14:24:03 -0800
Subject: [PATCH] add tests for the new sharded ddp fairscale integration (#9177)

---
Notes (this region is ignored when the patch is applied):
  * Adds a local `require_fairscale` decorator that skips a test via
    `unittest.skip` when `is_fairscale_available()` returns False.
  * Adds two multi-GPU tests that run the quick finetune path with
    `--sharded_ddp` and with `--sharded_ddp --fp16`.
  * Threads a new optional `extra_args_str` argument through
    `finetune_trainer_quick` and `run_trainer`; when it is not None it is
    split on whitespace and appended to the trainer argument list.
  * NOTE(review): the line structure of this patch was restored from a
    whitespace-collapsed copy. Hunk line counts were checked against the
    hunk headers, but blank context lines and context indentation are
    inferred -- confirm against the upstream commit before applying.

 examples/seq2seq/test_finetune_trainer.py | 38 +++++++++++++++++++++--
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py
index 92bad878aa..cf16e69a1b 100644
--- a/examples/seq2seq/test_finetune_trainer.py
+++ b/examples/seq2seq/test_finetune_trainer.py
@@ -14,10 +14,12 @@
 
 import os
 import sys
+import unittest
 from unittest.mock import patch
 
 from transformers import BertTokenizer, EncoderDecoderModel
 from transformers.file_utils import is_datasets_available
+from transformers.integrations import is_fairscale_available
 from transformers.testing_utils import (
     TestCasePlus,
     execute_subprocess_async,
@@ -38,9 +40,20 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 MBART_TINY = "sshleifer/tiny-mbart"
 
 
+# a candidate for testing_utils
+def require_fairscale(test_case):
+    """
+    Decorator marking a test that requires fairscale
+    """
+    if not is_fairscale_available():
+        return unittest.skip("test requires fairscale")(test_case)
+    else:
+        return test_case
+
+
 class TestFinetuneTrainer(TestCasePlus):
-    def finetune_trainer_quick(self, distributed=None):
-        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed)
+    def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
@@ -59,6 +72,16 @@ class TestFinetuneTrainer(TestCasePlus):
     def test_finetune_trainer_ddp(self):
         self.finetune_trainer_quick(distributed=True)
 
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_finetune_trainer_ddp_sharded_ddp(self):
+        self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp")
+
+    @require_torch_multi_gpu
+    @require_fairscale
+    def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
+        self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
+
     @slow
     def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
@@ -195,7 +218,13 @@ class TestFinetuneTrainer(TestCasePlus):
         trainer.train()
 
     def run_trainer(
-        self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int, distributed: bool = False
+        self,
+        eval_steps: int,
+        max_len: str,
+        model_name: str,
+        num_train_epochs: int,
+        distributed: bool = False,
+        extra_args_str: str = None,
     ):
         data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
         output_dir = self.get_auto_remove_tmp_dir()
@@ -231,6 +260,9 @@ class TestFinetuneTrainer(TestCasePlus):
         """.split()
         # --eval_beams 2
 
+        if extra_args_str is not None:
+            args.extend(extra_args_str.split())
+
         if distributed:
             n_gpu = get_gpu_count()
             distributed_args = f"""