[Deepspeed] support zero.Init in from_config (#11805)
* support zero.Init in from_config * no need for eval test
This commit is contained in:
@@ -18,9 +18,14 @@ import types
|
|||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...file_utils import copy_func
|
from ...file_utils import copy_func
|
||||||
|
from ...integrations import deepspeed_config, is_deepspeed_zero3_enabled
|
||||||
|
from ...utils import logging
|
||||||
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
|
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
CLASS_DOCSTRING = """
|
CLASS_DOCSTRING = """
|
||||||
This is a generic model class that will be instantiated as one of the model classes of the library when created
|
This is a generic model class that will be instantiated as one of the model classes of the library when created
|
||||||
with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the
|
with the :meth:`~transformers.BaseAutoModelClass.from_pretrained` class method or the
|
||||||
@@ -362,6 +367,15 @@ class _BaseAutoModelClass:
|
|||||||
def from_config(cls, config, **kwargs):
|
def from_config(cls, config, **kwargs):
|
||||||
if type(config) in cls._model_mapping.keys():
|
if type(config) in cls._model_mapping.keys():
|
||||||
model_class = _get_model_class(config, cls._model_mapping)
|
model_class = _get_model_class(config, cls._model_mapping)
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
|
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||||
|
# this immediately partitions the model across all gpus, to avoid the overhead in time
|
||||||
|
# and memory copying it on CPU or each GPU first
|
||||||
|
with deepspeed.zero.Init(config=deepspeed_config()):
|
||||||
|
return model_class(config, **kwargs)
|
||||||
|
else:
|
||||||
return model_class(config, **kwargs)
|
return model_class(config, **kwargs)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from transformers.file_utils import WEIGHTS_NAME
|
|||||||
from transformers.integrations import is_deepspeed_available
|
from transformers.integrations import is_deepspeed_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
|
CaptureStderr,
|
||||||
ExtendSysPath,
|
ExtendSysPath,
|
||||||
TestCasePlus,
|
TestCasePlus,
|
||||||
execute_subprocess_async,
|
execute_subprocess_async,
|
||||||
@@ -741,7 +742,38 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||||
execute_subprocess_async(cmd, env=self.get_env())
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
|
|
||||||
return output_dir
|
def test_clm_from_config_zero3(self):
|
||||||
|
# this test exercises AutoModel.from_config(config) - to ensure zero.Init is called
|
||||||
|
|
||||||
|
data_dir = self.tests_dir / "fixtures"
|
||||||
|
output_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
args = f"""
|
||||||
|
--model_type gpt2
|
||||||
|
--tokenizer_name sshleifer/tiny-gpt2
|
||||||
|
--train_file {data_dir}/sample_text.txt
|
||||||
|
--validation_file {data_dir}/sample_text.txt
|
||||||
|
--output_dir {output_dir}
|
||||||
|
--overwrite_output_dir
|
||||||
|
--do_train
|
||||||
|
--max_train_samples 4
|
||||||
|
--per_device_train_batch_size 2
|
||||||
|
--num_train_epochs 1
|
||||||
|
--warmup_steps 8
|
||||||
|
--block_size 8
|
||||||
|
--fp16
|
||||||
|
--report_to none
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_zero3.json".split()
|
||||||
|
script = [f"{self.examples_dir_str}/pytorch/language-modeling/run_clm.py"]
|
||||||
|
launcher = self.get_launcher(distributed=True)
|
||||||
|
|
||||||
|
cmd = launcher + script + args + ds_args
|
||||||
|
# keep for quick debug
|
||||||
|
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||||
|
with CaptureStderr() as cs:
|
||||||
|
execute_subprocess_async(cmd, env=self.get_env())
|
||||||
|
assert "Detected DeepSpeed ZeRO-3" in cs.err
|
||||||
|
|
||||||
def get_launcher(self, distributed=False):
|
def get_launcher(self, distributed=False):
|
||||||
# 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
|
# 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
|
||||||
|
|||||||
Reference in New Issue
Block a user