From 0b774074a5f7f2137d0f1743bb9990cfb7e7a1d8 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 31 May 2023 14:10:46 +0530 Subject: [PATCH] move fsdp handling to accelerate (#23158) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mixed precision support via accelerate * fix issues * fix for the sharded ddp case * fix flax and tf failing tests * `refactor the place to create `Accelerator` object * move ddp prep to accelerate * fix 😅 * resolving comments * move fsdp handling to accelerate * fixex * fix saving --- src/transformers/trainer.py | 193 ++++++++++++------------------ src/transformers/training_args.py | 28 ++++- 2 files changed, 105 insertions(+), 116 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 567fe0ac84..491d37cbc2 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -343,6 +343,12 @@ class Trainer: # create accelerator object self.accelerator = Accelerator() + # post accelerator creation setup + if getattr(self.accelerator.state, "fsdp_plugin", None) is not None: + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False) + fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False) + # memory metrics - must set up as early as possible self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) self._memory_tracker.start() @@ -464,7 +470,7 @@ class Trainer: self.fsdp = ShardingStrategy.NO_SHARD self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE - if "backward_prefetch" in self.args.fsdp_config and "backward_pos" in self.args.fsdp_config.get( + if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get( "backward_prefetch", [] ): self.backward_prefetch = BackwardPrefetch.BACKWARD_POST @@ -1479,114 +1485,58 @@ class Trainer: cpu_offload=cpu_offload, ).to(self.args.device) # Distributed training using PyTorch FSDP - elif self.fsdp is not None: - if not self.args.fsdp_config["xla"]: - # PyTorch FSDP! - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy - - if FSDPOption.OFFLOAD in self.args.fsdp: - cpu_offload = CPUOffload(offload_params=True) - else: - cpu_offload = CPUOffload(offload_params=False) - - auto_wrap_policy = None - - if FSDPOption.AUTO_WRAP in self.args.fsdp: - if self.args.fsdp_config["fsdp_min_num_params"] > 0: - auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] - ) - elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: - transformer_cls_to_wrap = set() - for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: - transformer_cls = get_module_class_from_name(model, layer_class) - if transformer_cls is None: - raise Exception("Could not find the transformer layer class to wrap in the model.") - else: - transformer_cls_to_wrap.add(transformer_cls) - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - # Transformer layer class to wrap - transformer_layer_cls=transformer_cls_to_wrap, - ) - mixed_precision_policy = None - dtype = None - if self.args.fp16: - dtype = torch.float16 - elif self.args.bf16: - dtype = torch.bfloat16 - if dtype is not None: - mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) - if type(model) != FSDP: - # XXX: Breaking the self.model convention but I see no way around it for now. - signature = inspect.signature(FSDP.__init__).parameters.keys() - kwargs = {} - for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]: - if arg in signature: - kwargs[arg] = getattr(self, arg) - self.model = model = FSDP( - model, - sharding_strategy=self.fsdp, - cpu_offload=cpu_offload, - auto_wrap_policy=auto_wrap_policy, - mixed_precision=mixed_precision_policy, - device_id=self.args.device, - **kwargs, - ) - else: - try: - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP - from torch_xla.distributed.fsdp import checkpoint_module - from torch_xla.distributed.fsdp.wrap import ( - size_based_auto_wrap_policy, - transformer_auto_wrap_policy, - ) - except ImportError: - raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") - auto_wrap_policy = None - auto_wrapper_callable = None - if self.args.fsdp_config["fsdp_min_num_params"] > 0: - auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] - ) - elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: - transformer_cls_to_wrap = set() - for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: - transformer_cls = get_module_class_from_name(model, layer_class) - if transformer_cls is None: - raise Exception("Could not find the transformer layer class to wrap in the model.") - else: - transformer_cls_to_wrap.add(transformer_cls) - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - # Transformer layer class to wrap - transformer_layer_cls=transformer_cls_to_wrap, - ) - fsdp_kwargs = self.args.xla_fsdp_config - if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: - # Apply gradient checkpointing to auto-wrapped sub-modules if specified - def auto_wrapper_callable(m, *args, **kwargs): - return FSDP(checkpoint_module(m), *args, **kwargs) - - # Wrap the base model with an outer FSDP wrapper - self.model = model = FSDP( - model, - auto_wrap_policy=auto_wrap_policy, - auto_wrapper_callable=auto_wrapper_callable, - **fsdp_kwargs, + elif self.fsdp is not None and self.args.fsdp_config["xla"]: + try: + from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP + from torch_xla.distributed.fsdp import checkpoint_module + from torch_xla.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, ) + except ImportError: + raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") + auto_wrap_policy = None + auto_wrapper_callable = None + if self.args.fsdp_config["fsdp_min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] + ) + elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + transformer_cls_to_wrap = set() + for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + fsdp_kwargs = self.args.xla_fsdp_config + if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + def auto_wrapper_callable(m, *args, **kwargs): + return FSDP(checkpoint_module(m), *args, **kwargs) - # Patch `xm.optimizer_step` should not reduce gradients in this case, - # as FSDP does not need gradient reduction over sharded parameters. - def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): - loss = optimizer.step(**optimizer_args) - if barrier: - xm.mark_step() - return loss + # Wrap the base model with an outer FSDP wrapper + self.model = model = FSDP( + model, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + **fsdp_kwargs, + ) - xm.optimizer_step = patched_optimizer_step + # Patch `xm.optimizer_step` should not reduce gradients in this case, + # as FSDP does not need gradient reduction over sharded parameters. + def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): + loss = optimizer.step(**optimizer_args) + if barrier: + xm.mark_step() + return loss + + xm.optimizer_step = patched_optimizer_step elif is_sagemaker_dp_enabled(): model = nn.parallel.DistributedDataParallel( model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] @@ -1796,17 +1746,26 @@ class Trainer: if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: self._load_from_checkpoint(resume_from_checkpoint, model) - # for the rest of this function `model` is the outside model, whether it was wrapped or not - if model is not self.model: - self.model_wrapped = model + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) # prepare using `accelerator` prepare - model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( - self.model, self.optimizer, self.lr_scheduler - ) + if use_accelerator_prepare: + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + if getattr(self.accelerator.state, "fsdp_plugin", None) is not None: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) @@ -2894,11 +2853,15 @@ class Trainer: ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp or self.fsdp is not None + or getattr(self.accelerator.state, "fsdp_plugin", None) is not None ): - state_dict = self.model.state_dict() + if getattr(self.accelerator.state, "fsdp_plugin", None) is not None: + self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir) + else: + state_dict = self.model.state_dict() - if self.args.should_save: - self._save(output_dir, state_dict=state_dict) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) elif self.deepspeed: # this takes care of everything as long as we aren't under zero3 if self.args.should_save: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index a0b11cd12e..001c8c5eeb 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -442,7 +442,7 @@ class TrainingArguments: - `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's gradient computation. - - `"backward_pos"` : This prefetches the next set of parameters after the current set of + - `"backward_post"` : This prefetches the next set of parameters after the current set of parameter’s gradient computation. - fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`) @@ -1504,6 +1504,32 @@ class TrainingArguments: if self.fsdp_config["xla_fsdp_grad_ckpt"]: warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.") + # accelerate integration for FSDP + if len(self.fsdp) > 0 and not self.fsdp_config["xla"]: + os.environ["ACCELERATE_USE_FSDP"] = "true" + from accelerate.utils.constants import ( + FSDP_AUTO_WRAP_POLICY, + FSDP_SHARDING_STRATEGY, + ) + + for fsdp_option in self.fsdp: + if fsdp_option.upper() in FSDP_SHARDING_STRATEGY: + # set environment variable for FSDP sharding strategy + os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1) + elif fsdp_option == FSDPOption.OFFLOAD: + os.environ["FSDP_OFFLOAD_PARAMS"] = "true" + elif fsdp_option == FSDPOption.AUTO_WRAP: + if self.fsdp_config["fsdp_min_num_params"] > 0: + os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"]) + os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] + elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join( + self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] + ) + os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] + prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH") + os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper() + if self.tpu_metrics_debug: warnings.warn( "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"