Smp grad accum (#10488)
* Fix gradient accumulation for SageMaker (SM) Model Parallelism
* Style fixes and divide loss by gradient accumulation steps
This commit is contained in:
@@ -37,9 +37,10 @@ if is_smdistributed_available():
|
|||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
@smp.step()
|
@smp.step()
|
||||||
def forward_backward(model, inputs):
|
def forward_backward(model, inputs, gradient_accumulation_steps=1):
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||||
|
loss /= gradient_accumulation_steps
|
||||||
model.backward(loss)
|
model.backward(loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@@ -73,8 +74,6 @@ class SageMakerTrainer(Trainer):
|
|||||||
def __init__(self, args=None, **kwargs):
|
def __init__(self, args=None, **kwargs):
|
||||||
self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
|
self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
|
||||||
super().__init__(args=args, **kwargs)
|
super().__init__(args=args, **kwargs)
|
||||||
if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
|
|
||||||
raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
|
|
||||||
|
|
||||||
def is_world_process_zero(self) -> bool:
|
def is_world_process_zero(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -108,7 +107,7 @@ class SageMakerTrainer(Trainer):
|
|||||||
# Wrapping the base model twice in a DistributedModel will raise an error.
|
# Wrapping the base model twice in a DistributedModel will raise an error.
|
||||||
if isinstance(self.model_wrapped, smp.model.DistributedModel):
|
if isinstance(self.model_wrapped, smp.model.DistributedModel):
|
||||||
return self.model_wrapped
|
return self.model_wrapped
|
||||||
return smp.DistributedModel(model)
|
return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
|
||||||
else:
|
else:
|
||||||
return super()._wrap_model(model)
|
return super()._wrap_model(model)
|
||||||
|
|
||||||
@@ -121,7 +120,7 @@ class SageMakerTrainer(Trainer):
|
|||||||
if self.is_model_parallel_enabled:
|
if self.is_model_parallel_enabled:
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
loss_mb = forward_backward(model, inputs)
|
loss_mb = forward_backward(model, inputs, self.args.gradient_accumulation_steps)
|
||||||
return loss_mb.reduce_mean().detach().to(self.args.device)
|
return loss_mb.reduce_mean().detach().to(self.args.device)
|
||||||
else:
|
else:
|
||||||
return super().training_step(model, inputs)
|
return super().training_step(model, inputs)
|
||||||
|
|||||||
@@ -87,3 +87,7 @@ class SageMakerTrainingArguments(TrainingArguments):
|
|||||||
@property
|
@property
|
||||||
def place_model_on_device(self):
|
def place_model_on_device(self):
|
||||||
return not (is_smdistributed_available() and self.mp_parameters != "")
|
return not (is_smdistributed_available() and self.mp_parameters != "")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _no_sync_in_gradient_accumulation(self):
|
||||||
|
return False
|
||||||
|
|||||||
@@ -1039,7 +1039,7 @@ class Trainer:
|
|||||||
if (
|
if (
|
||||||
((step + 1) % self.args.gradient_accumulation_steps != 0)
|
((step + 1) % self.args.gradient_accumulation_steps != 0)
|
||||||
and self.args.local_rank != -1
|
and self.args.local_rank != -1
|
||||||
and not self.args.deepspeed
|
and self.args._no_sync_in_gradient_accumulation
|
||||||
):
|
):
|
||||||
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
|
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
|
||||||
with model.no_sync():
|
with model.no_sync():
|
||||||
|
|||||||
@@ -737,6 +737,13 @@ class TrainingArguments:
|
|||||||
"""
|
"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _no_sync_in_gradient_accumulation(self):
|
||||||
|
"""
|
||||||
|
Whether or not to use no_sync for the gradients when doing gradient accumulation.
|
||||||
|
"""
|
||||||
|
return not self.deepspeed
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
Serializes this instance while replace `Enum` by their values (for JSON serialization support).
|
Serializes this instance while replace `Enum` by their values (for JSON serialization support).
|
||||||
|
|||||||
Reference in New Issue
Block a user