accelerate deepspeed and gradient accumulation integrate (#23236)

* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix flax and tf failing tests

* `refactor the place to create `Accelerator` object

* move ddp prep to accelerate

* fix 😅

* resolving comments

* move fsdp handling to accelerate

* fixex

* fix saving

* shift torch dynamo handling to accelerate

* shift deepspeed integration and save & load utils to accelerate

* fix accelerate launcher support

* oops

* fix 🐛

* save ckpt fix

* Trigger CI

* nasty 🐛 😅

* as deepspeed needs grad_acc fixes, transfer grad_acc to accelerate

* make tests happy

* quality 

* loss tracked needs to account for grad_acc

* fixing the deepspeed tests

* quality 

* 😅😅😅

* tests 😡

* quality 

* Trigger CI

* resolve comments and fix the issue with the previous merge from branch

* Trigger CI

* accelerate took over deepspeed integration

---------

Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Sourab Mangrulkar
2023-05-31 15:16:22 +05:30
committed by GitHub
parent 88f50a1e89
commit a73b1d59a3
8 changed files with 164 additions and 161 deletions

View File

@@ -64,7 +64,7 @@ if is_torch_available():
import torch.distributed as dist
if is_accelerate_available():
from accelerate import PartialState
from accelerate.state import AcceleratorState, PartialState
from accelerate.utils import DistributedType
if is_torch_tpu_available(check_device=False):
@@ -1550,6 +1550,7 @@ class TrainingArguments:
if isinstance(self.debug, str):
self.debug = [DebugOption(s) for s in self.debug.split()]
self.deepspeed_plugin = None
if self.deepspeed:
# - must be run very last in arg parsing, since it will use a lot of these settings.
# - must be run before the model is created.
@@ -1562,6 +1563,12 @@ class TrainingArguments:
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
self.hf_deepspeed_config.trainer_config_process(self)
# Accelerate DeepSpeed Plugin
from accelerate.utils import DeepSpeedPlugin
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
if self.push_to_hub_token is not None:
warnings.warn(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
@@ -1660,6 +1667,8 @@ class TrainingArguments:
def _setup_devices(self) -> "torch.device":
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
AcceleratorState._reset_state()
PartialState._reset_state()
if not is_sagemaker_mp_enabled() and not is_accelerate_available(check_partial_state=True):
raise ImportError(
"Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"