@@ -1815,15 +1815,10 @@ class TrainingArguments:
|
||||
The number of processes used in parallel.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
if is_torch_tpu_available():
|
||||
return xm.xrt_world_size()
|
||||
if self.distributed_state is not None:
|
||||
return self.distributed_state.num_processes
|
||||
elif is_sagemaker_mp_enabled():
|
||||
return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
|
||||
elif is_sagemaker_dp_enabled():
|
||||
return dist.get_world_size()
|
||||
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
return torch.distributed.get_world_size()
|
||||
return 1
|
||||
|
||||
@property
|
||||
@@ -1832,14 +1827,10 @@ class TrainingArguments:
|
||||
The index of the current process used.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
if is_torch_tpu_available():
|
||||
return xm.get_ordinal()
|
||||
if self.distributed_state is not None:
|
||||
return self.distributed_state.process_index
|
||||
elif is_sagemaker_mp_enabled():
|
||||
return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
|
||||
elif is_sagemaker_dp_enabled():
|
||||
return dist.get_rank()
|
||||
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
return torch.distributed.get_rank()
|
||||
return 0
|
||||
|
||||
@property
|
||||
@@ -1848,14 +1839,11 @@ class TrainingArguments:
|
||||
The index of the local process used.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
if is_torch_tpu_available():
|
||||
return xm.get_local_ordinal()
|
||||
|
||||
if self.distributed_state is not None:
|
||||
return self.distributed_state.local_process_index
|
||||
elif is_sagemaker_mp_enabled():
|
||||
return smp.local_rank()
|
||||
elif is_sagemaker_dp_enabled():
|
||||
return dist.get_rank()
|
||||
elif self.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
return self.local_rank
|
||||
return 0
|
||||
|
||||
@property
|
||||
@@ -1944,19 +1932,19 @@ class TrainingArguments:
|
||||
|
||||
"""
|
||||
if is_torch_available() and self.world_size > 1:
|
||||
main_process_desc = "main process"
|
||||
if local:
|
||||
is_main_process = self.local_process_index == 0
|
||||
main_process_desc = "main local process"
|
||||
main_process_desc = "main local process" if local else "main process"
|
||||
if self.distributed_state is not None:
|
||||
is_main_process = (
|
||||
self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process
|
||||
)
|
||||
elif is_sagemaker_mp_enabled():
|
||||
is_main_process = smp.rank() == 0
|
||||
else:
|
||||
is_main_process = self.process_index == 0
|
||||
|
||||
try:
|
||||
if not is_main_process:
|
||||
# tell all replicas to wait
|
||||
logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous(desc)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user