From 127e81c272060709e451627903082fd3b55a7039 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Wed, 21 Jun 2023 11:51:27 -0400 Subject: [PATCH] Remove redundant code from TrainingArgs (#24401) Remove redundant code --- src/transformers/training_args.py | 38 +++++++++++-------------------- 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index dc692f5aa7..b27e9bba4c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1815,15 +1815,10 @@ class TrainingArguments: The number of processes used in parallel. """ requires_backends(self, ["torch"]) - - if is_torch_tpu_available(): - return xm.xrt_world_size() + if self.distributed_state is not None: + return self.distributed_state.num_processes elif is_sagemaker_mp_enabled(): return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size() - elif is_sagemaker_dp_enabled(): - return dist.get_world_size() - elif self.parallel_mode == ParallelMode.DISTRIBUTED: - return torch.distributed.get_world_size() return 1 @property @@ -1832,14 +1827,10 @@ class TrainingArguments: The index of the current process used. """ requires_backends(self, ["torch"]) - if is_torch_tpu_available(): - return xm.get_ordinal() + if self.distributed_state is not None: + return self.distributed_state.process_index elif is_sagemaker_mp_enabled(): return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank() - elif is_sagemaker_dp_enabled(): - return dist.get_rank() - elif self.parallel_mode == ParallelMode.DISTRIBUTED: - return torch.distributed.get_rank() return 0 @property @@ -1848,14 +1839,11 @@ class TrainingArguments: The index of the local process used. """ requires_backends(self, ["torch"]) - if is_torch_tpu_available(): - return xm.get_local_ordinal() + + if self.distributed_state is not None: + return self.distributed_state.local_process_index elif is_sagemaker_mp_enabled(): return smp.local_rank() - elif is_sagemaker_dp_enabled(): - return dist.get_rank() - elif self.parallel_mode == ParallelMode.DISTRIBUTED: - return self.local_rank return 0 @property @@ -1944,19 +1932,19 @@ class TrainingArguments: """ if is_torch_available() and self.world_size > 1: - main_process_desc = "main process" - if local: - is_main_process = self.local_process_index == 0 - main_process_desc = "main local process" + main_process_desc = "main local process" if local else "main process" + if self.distributed_state is not None: + is_main_process = ( + self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process + ) elif is_sagemaker_mp_enabled(): is_main_process = smp.rank() == 0 - else: - is_main_process = self.process_index == 0 try: if not is_main_process: # tell all replicas to wait logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}") + if is_torch_tpu_available(): xm.rendezvous(desc) else: