From 9f0646a5550ccfb49a139abe14edb724edf785f5 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 31 May 2023 12:27:51 +0530 Subject: [PATCH] Smangrul/accelerate mp integrate (#23148) * mixed precision support via accelerate * fix issues * fix for the sharded ddp case * fix flax and tf failing tests * refactor the place to create `Accelerator` object * address comments by removing debugging print statements --- src/transformers/trainer.py | 68 +++++++++++++++++++------------ src/transformers/training_args.py | 9 ++++ 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 357cfc45bd..842a410d15 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -212,6 +212,8 @@ if is_accelerate_available(): if version.parse(accelerate_version) >= version.parse("0.16"): from accelerate import skip_first_batches + from accelerate import Accelerator + if TYPE_CHECKING: import optuna @@ -337,6 +339,9 @@ class Trainer: self.deepspeed = None self.is_in_train = False + # create accelerator object + self.accelerator = Accelerator() + # memory metrics - must set up as early as possible self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) self._memory_tracker.start() @@ -607,7 +612,7 @@ class Trainer: "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." 
) - if args.fp16 or args.bf16: + if (args.fp16 or args.bf16) and self.sharded_ddp is not None: if args.half_precision_backend == "auto": if args.device == torch.device("cpu"): if args.fp16: @@ -624,30 +629,31 @@ class Trainer: self.do_grad_scaling = False if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()): # deepspeed and SageMaker Model Parallel manage their own half precision - if args.half_precision_backend == "cuda_amp": - self.use_cuda_amp = True - self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 - # bf16 does not need grad scaling - self.do_grad_scaling = self.amp_dtype == torch.float16 - if self.do_grad_scaling: - if self.sharded_ddp is not None: - self.scaler = ShardedGradScaler() - elif self.fsdp is not None: - from torch.distributed.fsdp.sharded_grad_scaler import ( - ShardedGradScaler as FSDPShardedGradScaler, - ) + if self.sharded_ddp is not None: + if args.half_precision_backend == "cuda_amp": + self.use_cuda_amp = True + self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 + # bf16 does not need grad scaling + self.do_grad_scaling = self.amp_dtype == torch.float16 + if self.do_grad_scaling: + if self.sharded_ddp is not None: + self.scaler = ShardedGradScaler() + elif self.fsdp is not None: + from torch.distributed.fsdp.sharded_grad_scaler import ( + ShardedGradScaler as FSDPShardedGradScaler, + ) - self.scaler = FSDPShardedGradScaler() - elif is_torch_tpu_available(): - from torch_xla.amp import GradScaler + self.scaler = FSDPShardedGradScaler() + elif is_torch_tpu_available(): + from torch_xla.amp import GradScaler - self.scaler = GradScaler() - else: - self.scaler = torch.cuda.amp.GradScaler() - elif args.half_precision_backend == "cpu_amp": - self.use_cpu_amp = True - self.amp_dtype = torch.bfloat16 - else: + self.scaler = GradScaler() + else: + self.scaler = torch.cuda.amp.GradScaler() + elif args.half_precision_backend == "cpu_amp": + self.use_cpu_amp = True + self.amp_dtype = 
torch.bfloat16 + elif args.half_precision_backend == "apex": if not is_apex_available(): raise ImportError( "Using FP16 with APEX but APEX is not installed, please refer to" @@ -1801,6 +1807,11 @@ class Trainer: if delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) + # prepare using `accelerator` prepare + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) @@ -2013,10 +2024,15 @@ class Trainer: elif hasattr(model, "clip_grad_norm_"): # Some models (like FullyShardedDDP) have a specific way to do gradient clipping model.clip_grad_norm_(args.max_grad_norm) - else: + elif self.use_apex: # Revert to normal clipping otherwise, handling Apex or full precision nn.utils.clip_grad_norm_( - amp.master_params(self.optimizer) if self.use_apex else model.parameters(), + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + self.accelerator.clip_grad_norm_( + model.parameters(), args.max_grad_norm, ) @@ -2802,7 +2818,7 @@ class Trainer: # loss gets scaled under gradient_accumulation_steps in deepspeed loss = self.deepspeed.backward(loss) else: - loss.backward() + self.accelerator.backward(loss) return loss.detach() diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 63876e053a..a0b11cd12e 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1562,6 +1562,15 @@ class TrainingArguments: FutureWarning, ) + # if training args is specified, it will override the one specified in the accelerate config + if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0: + mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") + if self.fp16: + mixed_precision_dtype = "fp16" + elif self.bf16: + mixed_precision_dtype = "bf16" + 
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype + def __str__(self): self_as_dict = asdict(self)