From a8aad0ec938778bf41df1e6842a7baab81776c64 Mon Sep 17 00:00:00 2001 From: Zachary Mueller Date: Wed, 19 Apr 2023 14:37:16 -0400 Subject: [PATCH] Fixup multigpu local_rank (#22869) Fixup multigpu tests --- src/transformers/training_args.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index fc2ba42786..4ab1829859 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1537,9 +1537,7 @@ class TrainingArguments: ) if self.no_cuda: self.distributed_state = PartialState(cpu=True) - device = self.distributed_state.device self._n_gpu = 0 - self.local_rank = self.distributed_state.local_process_index elif is_sagemaker_mp_enabled(): local_rank = smp.local_rank() device = torch.device("cuda", local_rank) @@ -1548,11 +1546,12 @@ class TrainingArguments: elif self.deepspeed: self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout)) self._n_gpu = 1 - device = self.distributed_state.device else: self.distributed_state = PartialState(backend=self.xpu_backend) - device = self.distributed_state.device self._n_gpu = 1 + if not is_sagemaker_mp_enabled(): + device = self.distributed_state.device + self.local_rank = self.distributed_state.local_process_index if ( torch.distributed.is_available() and torch.distributed.is_initialized()