Smangrul/accelerate mp integrate (#23148)
* mixed precision support via accelerate * fix issues * fix for the sharded ddp case * fix flax and tf failing tests * `refactor the place to create `Accelerator` object * address comments by removing debugging print statements
This commit is contained in:
committed by
GitHub
parent
de9255de27
commit
9f0646a555
@@ -212,6 +212,8 @@ if is_accelerate_available():
|
|||||||
if version.parse(accelerate_version) >= version.parse("0.16"):
|
if version.parse(accelerate_version) >= version.parse("0.16"):
|
||||||
from accelerate import skip_first_batches
|
from accelerate import skip_first_batches
|
||||||
|
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import optuna
|
import optuna
|
||||||
@@ -337,6 +339,9 @@ class Trainer:
|
|||||||
self.deepspeed = None
|
self.deepspeed = None
|
||||||
self.is_in_train = False
|
self.is_in_train = False
|
||||||
|
|
||||||
|
# create accelerator object
|
||||||
|
self.accelerator = Accelerator()
|
||||||
|
|
||||||
# memory metrics - must set up as early as possible
|
# memory metrics - must set up as early as possible
|
||||||
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
|
||||||
self._memory_tracker.start()
|
self._memory_tracker.start()
|
||||||
@@ -607,7 +612,7 @@ class Trainer:
|
|||||||
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
|
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.fp16 or args.bf16:
|
if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
|
||||||
if args.half_precision_backend == "auto":
|
if args.half_precision_backend == "auto":
|
||||||
if args.device == torch.device("cpu"):
|
if args.device == torch.device("cpu"):
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
@@ -624,6 +629,7 @@ class Trainer:
|
|||||||
self.do_grad_scaling = False
|
self.do_grad_scaling = False
|
||||||
if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
|
if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
|
||||||
# deepspeed and SageMaker Model Parallel manage their own half precision
|
# deepspeed and SageMaker Model Parallel manage their own half precision
|
||||||
|
if self.sharded_ddp is not None:
|
||||||
if args.half_precision_backend == "cuda_amp":
|
if args.half_precision_backend == "cuda_amp":
|
||||||
self.use_cuda_amp = True
|
self.use_cuda_amp = True
|
||||||
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
||||||
@@ -647,7 +653,7 @@ class Trainer:
|
|||||||
elif args.half_precision_backend == "cpu_amp":
|
elif args.half_precision_backend == "cpu_amp":
|
||||||
self.use_cpu_amp = True
|
self.use_cpu_amp = True
|
||||||
self.amp_dtype = torch.bfloat16
|
self.amp_dtype = torch.bfloat16
|
||||||
else:
|
elif args.half_precision_backend == "apex":
|
||||||
if not is_apex_available():
|
if not is_apex_available():
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Using FP16 with APEX but APEX is not installed, please refer to"
|
"Using FP16 with APEX but APEX is not installed, please refer to"
|
||||||
@@ -1801,6 +1807,11 @@ class Trainer:
|
|||||||
if delay_optimizer_creation:
|
if delay_optimizer_creation:
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||||
|
|
||||||
|
# prepare using `accelerator` prepare
|
||||||
|
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
||||||
|
self.model, self.optimizer, self.lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
self._load_optimizer_and_scheduler(resume_from_checkpoint)
|
||||||
|
|
||||||
@@ -2013,10 +2024,15 @@ class Trainer:
|
|||||||
elif hasattr(model, "clip_grad_norm_"):
|
elif hasattr(model, "clip_grad_norm_"):
|
||||||
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
|
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
|
||||||
model.clip_grad_norm_(args.max_grad_norm)
|
model.clip_grad_norm_(args.max_grad_norm)
|
||||||
else:
|
elif self.use_apex:
|
||||||
# Revert to normal clipping otherwise, handling Apex or full precision
|
# Revert to normal clipping otherwise, handling Apex or full precision
|
||||||
nn.utils.clip_grad_norm_(
|
nn.utils.clip_grad_norm_(
|
||||||
amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
|
amp.master_params(self.optimizer),
|
||||||
|
args.max_grad_norm,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.accelerator.clip_grad_norm_(
|
||||||
|
model.parameters(),
|
||||||
args.max_grad_norm,
|
args.max_grad_norm,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2802,7 +2818,7 @@ class Trainer:
|
|||||||
# loss gets scaled under gradient_accumulation_steps in deepspeed
|
# loss gets scaled under gradient_accumulation_steps in deepspeed
|
||||||
loss = self.deepspeed.backward(loss)
|
loss = self.deepspeed.backward(loss)
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
return loss.detach()
|
return loss.detach()
|
||||||
|
|
||||||
|
|||||||
@@ -1562,6 +1562,15 @@ class TrainingArguments:
|
|||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if training args is specified, it will override the one specified in the accelerate config
|
||||||
|
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
|
||||||
|
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
|
||||||
|
if self.fp16:
|
||||||
|
mixed_precision_dtype = "fp16"
|
||||||
|
elif self.bf16:
|
||||||
|
mixed_precision_dtype = "bf16"
|
||||||
|
os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
self_as_dict = asdict(self)
|
self_as_dict = asdict(self)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user