@@ -1815,15 +1815,10 @@ class TrainingArguments:
|
|||||||
The number of processes used in parallel.
|
The number of processes used in parallel.
|
||||||
"""
|
"""
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
if self.distributed_state is not None:
|
||||||
if is_torch_tpu_available():
|
return self.distributed_state.num_processes
|
||||||
return xm.xrt_world_size()
|
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
|
return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
|
||||||
elif is_sagemaker_dp_enabled():
|
|
||||||
return dist.get_world_size()
|
|
||||||
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
|
|
||||||
return torch.distributed.get_world_size()
|
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1832,14 +1827,10 @@ class TrainingArguments:
|
|||||||
The index of the current process used.
|
The index of the current process used.
|
||||||
"""
|
"""
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
if is_torch_tpu_available():
|
if self.distributed_state is not None:
|
||||||
return xm.get_ordinal()
|
return self.distributed_state.process_index
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
|
return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
|
||||||
elif is_sagemaker_dp_enabled():
|
|
||||||
return dist.get_rank()
|
|
||||||
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
|
|
||||||
return torch.distributed.get_rank()
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1848,14 +1839,11 @@ class TrainingArguments:
|
|||||||
The index of the local process used.
|
The index of the local process used.
|
||||||
"""
|
"""
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
if is_torch_tpu_available():
|
|
||||||
return xm.get_local_ordinal()
|
if self.distributed_state is not None:
|
||||||
|
return self.distributed_state.local_process_index
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
return smp.local_rank()
|
return smp.local_rank()
|
||||||
elif is_sagemaker_dp_enabled():
|
|
||||||
return dist.get_rank()
|
|
||||||
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
|
|
||||||
return self.local_rank
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1944,19 +1932,19 @@ class TrainingArguments:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if is_torch_available() and self.world_size > 1:
|
if is_torch_available() and self.world_size > 1:
|
||||||
main_process_desc = "main process"
|
main_process_desc = "main local process" if local else "main process"
|
||||||
if local:
|
if self.distributed_state is not None:
|
||||||
is_main_process = self.local_process_index == 0
|
is_main_process = (
|
||||||
main_process_desc = "main local process"
|
self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process
|
||||||
|
)
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
is_main_process = smp.rank() == 0
|
is_main_process = smp.rank() == 0
|
||||||
else:
|
|
||||||
is_main_process = self.process_index == 0
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_main_process:
|
if not is_main_process:
|
||||||
# tell all replicas to wait
|
# tell all replicas to wait
|
||||||
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
|
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
xm.rendezvous(desc)
|
xm.rendezvous(desc)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user