Sagemaker Model Parallel tensoboard writing fix (#10403)
* Added tb fix * Removed local rank condition * Updated reference to args
This commit is contained in:
@@ -71,11 +71,21 @@ if is_smdistributed_available():
|
|||||||
|
|
||||||
class SageMakerTrainer(Trainer):
|
class SageMakerTrainer(Trainer):
|
||||||
def __init__(self, args=None, **kwargs):
|
def __init__(self, args=None, **kwargs):
|
||||||
|
self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
|
||||||
super().__init__(args=args, **kwargs)
|
super().__init__(args=args, **kwargs)
|
||||||
self.is_model_parallel_enabled = is_smdistributed_available() and self.args.mp_parameters != ""
|
|
||||||
if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
|
if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
|
||||||
raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
|
raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
|
||||||
|
|
||||||
|
def is_world_process_zero(self) -> bool:
|
||||||
|
"""
|
||||||
|
Whether or not this process is the global main process (when training in a distributed fashion on several
|
||||||
|
machines, this is only going to be :obj:`True` for one process).
|
||||||
|
"""
|
||||||
|
if self.is_model_parallel_enabled:
|
||||||
|
return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0
|
||||||
|
else:
|
||||||
|
return super.is_world_process_zero()
|
||||||
|
|
||||||
def _get_train_sampler(self):
|
def _get_train_sampler(self):
|
||||||
if self.is_model_parallel_enabled:
|
if self.is_model_parallel_enabled:
|
||||||
if self.args.group_by_length:
|
if self.args.group_by_length:
|
||||||
|
|||||||
Reference in New Issue
Block a user