diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index ea91ddfe81..4c13ad0af4 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1641,7 +1641,10 @@ class TrainingArguments: # Here, we'll use torch.distributed. # Initializes the distributed backend which will take care of synchronizing nodes/GPUs if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta) + if self.xpu_backend and self.xpu_backend in ("mpi", "gloo"): + torch.distributed.init_process_group(backend=self.xpu_backend, timeout=self.ddp_timeout_delta) + else: + torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta) device = torch.device("cuda", self.local_rank) self._n_gpu = 1