accelerate deepspeed and gradient accumulation integration (#23236)
* mixed precision support via accelerate * fix issues * fix for the sharded ddp case * fix flax and tf failing tests * refactor the place to create `Accelerator` object * move ddp prep to accelerate * fix 😅 * resolving comments * move fsdp handling to accelerate * fixes * fix saving * shift torch dynamo handling to accelerate * shift deepspeed integration and save & load utils to accelerate * fix accelerate launcher support * oops * fix 🐛 * save ckpt fix * Trigger CI * nasty 🐛 😅 * as deepspeed needs grad_acc fixes, transfer grad_acc to accelerate * make tests happy * quality ✨ * loss tracked needs to account for grad_acc * fixing the deepspeed tests * quality ✨ * 😅😅😅 * tests 😡 * quality ✨ * Trigger CI * resolve comments and fix the issue with the previous merge from branch * Trigger CI * accelerate took over deepspeed integration --------- Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
committed by
GitHub
parent
88f50a1e89
commit
a73b1d59a3
@@ -64,7 +64,7 @@ if is_torch_available():
|
||||
import torch.distributed as dist
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import PartialState
|
||||
from accelerate.state import AcceleratorState, PartialState
|
||||
from accelerate.utils import DistributedType
|
||||
|
||||
if is_torch_tpu_available(check_device=False):
|
||||
@@ -1550,6 +1550,7 @@ class TrainingArguments:
|
||||
if isinstance(self.debug, str):
|
||||
self.debug = [DebugOption(s) for s in self.debug.split()]
|
||||
|
||||
self.deepspeed_plugin = None
|
||||
if self.deepspeed:
|
||||
# - must be run very last in arg parsing, since it will use a lot of these settings.
|
||||
# - must be run before the model is created.
|
||||
@@ -1562,6 +1563,12 @@ class TrainingArguments:
|
||||
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
|
||||
self.hf_deepspeed_config.trainer_config_process(self)
|
||||
|
||||
# Accelerate DeepSpeed Plugin
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
|
||||
|
||||
if self.push_to_hub_token is not None:
|
||||
warnings.warn(
|
||||
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||
@@ -1660,6 +1667,8 @@ class TrainingArguments:
|
||||
def _setup_devices(self) -> "torch.device":
|
||||
requires_backends(self, ["torch"])
|
||||
logger.info("PyTorch: setting up devices")
|
||||
AcceleratorState._reset_state()
|
||||
PartialState._reset_state()
|
||||
if not is_sagemaker_mp_enabled() and not is_accelerate_available(check_partial_state=True):
|
||||
raise ImportError(
|
||||
"Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
|
||||
|
||||
Reference in New Issue
Block a user