[Deepspeed] adapt multiple models, add zero_to_fp32 tests (#12477)

* zero_to_fp32 tests

* args change

* remove unnecessary work

* use transformers.trainer_utils.get_last_checkpoint

* document the new features

* cleanup

* wip

* fix fsmt

* add bert

* cleanup

* add xlm-roberta

* electra works

* cleanup

* sync

* split off the model zoo tests

* cleanup

* cleanup

* cleanup

* cleanup

* reformat

* cleanup

* casing

* deepspeed>=0.4.3

* adjust distilbert

* Update docs/source/main_classes/deepspeed.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-07-13 12:07:32 -07:00
committed by GitHub
parent 65bf05cd18
commit 78f5fe1416
10 changed files with 444 additions and 80 deletions

View File

@@ -37,11 +37,12 @@ from transformers.testing_utils import (
require_torch_multi_gpu,
slow,
)
from transformers.trainer_utils import set_seed
from transformers.trainer_utils import get_last_checkpoint, set_seed
bindir = os.path.abspath(os.path.dirname(__file__))
with ExtendSysPath(f"{bindir}/.."):
tests_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
root_dir = os.path.dirname(tests_dir)
with ExtendSysPath(tests_dir):
from test_trainer import TrainerIntegrationCommon # noqa
if is_torch_available():
@@ -49,9 +50,10 @@ with ExtendSysPath(f"{bindir}/.."):
set_seed(42)
MBART_TINY = "sshleifer/tiny-mbart"
T5_SMALL = "t5-small"
T5_TINY = "patrickvonplaten/t5-tiny-random"
GPT2_TINY = "sshleifer/tiny-gpt2"
def load_json(path):
@@ -77,8 +79,19 @@ def require_deepspeed_aio(test_case):
if is_deepspeed_available():
from deepspeed.utils import logger as deepspeed_logger # noqa
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled # noqa
def get_launcher(distributed=False):
# 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
# - it won't be able to handle that
# 2. for now testing with just 2 gpus max (since some quality tests may give different
# results with mode gpus because we use very little data)
num_gpus = min(2, get_gpu_count()) if distributed else 1
return f"deepspeed --num_nodes 1 --num_gpus {num_gpus}".split()
ZERO2 = "zero2"
ZERO3 = "zero3"
stages = [ZERO2, ZERO3]
@@ -568,6 +581,41 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
@parameterized.expand(stages)
def test_load_state_dict_from_zero_checkpoint(self, stage):
# test that we can load fp32 weights directly from the zero checkpoint into the current model
output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False, before=False)
ds_config_dict = self.get_config_dict(stage)
kwargs = dict(
output_dir=output_dir,
train_len=4,
per_device_train_batch_size=4,
num_train_epochs=1,
save_strategy="steps",
save_steps=1,
learning_rate=0.1,
fp16=True,
deepspeed=ds_config_dict,
)
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(**kwargs)
trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()
state = dataclasses.asdict(trainer.state)
checkpoint_dir = get_last_checkpoint(output_dir)
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
(a1, b1) = model.a.item(), model.b.item()
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.check_trainer_state_are_the_same(state, state1)
def test_config_object(self):
# test that we can switch from zero2 to zero3 in the same process for example
# test is_zero, etc.
@@ -809,7 +857,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
launcher = self.get_launcher(distributed)
launcher = get_launcher(distributed)
cmd = launcher + script + args + ds_args
# keep for quick debug
@@ -826,7 +874,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
data_dir = self.tests_dir / "fixtures"
output_dir = self.get_auto_remove_tmp_dir()
args = f"""
--model_name_or_path sshleifer/tiny-gpt2
--model_name_or_path {GPT2_TINY}
--train_file {data_dir}/sample_text.txt
--validation_file {data_dir}/sample_text.txt
--output_dir {output_dir}
@@ -846,7 +894,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
script = [f"{self.examples_dir_str}/pytorch/language-modeling/run_clm.py"]
launcher = self.get_launcher(distributed=True)
launcher = get_launcher(distributed=True)
cmd = launcher + script + args + ds_args
# keep for quick debug
@@ -860,7 +908,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
output_dir = self.get_auto_remove_tmp_dir()
args = f"""
--model_type gpt2
--tokenizer_name sshleifer/tiny-gpt2
--tokenizer_name {GPT2_TINY}
--train_file {data_dir}/sample_text.txt
--validation_file {data_dir}/sample_text.txt
--output_dir {output_dir}
@@ -877,7 +925,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
script = [f"{self.examples_dir_str}/pytorch/language-modeling/run_clm.py"]
launcher = self.get_launcher(distributed=True)
launcher = get_launcher(distributed=True)
cmd = launcher + script + args + ds_args
# keep for quick debug
@@ -885,11 +933,3 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
with CaptureStderr() as cs:
execute_subprocess_async(cmd, env=self.get_env())
assert "Detected DeepSpeed ZeRO-3" in cs.err
def get_launcher(self, distributed=False):
# 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
# - it won't be able to handle that
# 2. for now testing with just 2 gpus max (since some quality tests may give different
# results with mode gpus because we use very little data)
num_gpus = min(2, get_gpu_count()) if distributed else 1
return f"deepspeed --num_nodes 1 --num_gpus {num_gpus}".split()