From b455ad0a6460569a2237f201d2be21d5792e8a6f Mon Sep 17 00:00:00 2001
From: Zachary Mueller
Date: Fri, 19 May 2023 12:44:24 -0400
Subject: [PATCH] Fix parallel mode check (#23409)

* Fix sagemaker/distributed state

* Fix correctly

* Bring back -1

* Bring back local rank for distributed check

* better version

* Cleanest option
---
 src/transformers/training_args.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 3189c5f86b..b42400e57a 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1613,6 +1613,7 @@ class TrainingArguments:
             raise ImportError(
                 "Using the `Trainer` with `PyTorch` requires `accelerate>=0.19.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
             )
+        self.distributed_state = None
         if self.no_cuda:
             self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
             self._n_gpu = 0
@@ -1636,7 +1637,7 @@ class TrainingArguments:
         if (
             torch.distributed.is_available()
             and torch.distributed.is_initialized()
-            and self.distributed_state.distributed_type == DistributedType.NO
+            and self.parallel_mode != ParallelMode.DISTRIBUTED
         ):
             logger.warning(
                 "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
@@ -1728,7 +1729,9 @@ class TrainingArguments:
             return ParallelMode.SAGEMAKER_MODEL_PARALLEL
         elif is_sagemaker_dp_enabled():
             return ParallelMode.SAGEMAKER_DATA_PARALLEL
-        elif hasattr(self, "distributed_state") and self.distributed_state.distributed_type != DistributedType.NO:
+        elif (
+            self.distributed_state is not None and self.distributed_state.distributed_type != DistributedType.NO
+        ) or (self.distributed_state is None and self.local_rank != -1):
             return ParallelMode.DISTRIBUTED
         elif self.n_gpu > 1:
             return ParallelMode.NOT_DISTRIBUTED