From 1420b5ff675ccdc3296c6776b339a08a22d2e941 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 29 Jan 2021 08:18:04 -0800 Subject: [PATCH] refactor deepspeed setup devices (#9880) --- src/transformers/training_args.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b426f2c430..3fe0d137b3 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -535,6 +535,20 @@ class TrainingArguments: self.local_rank = dist.get_local_rank() device = torch.device("cuda", self.local_rank) self._n_gpu = 1 + elif self.deepspeed: + # deepspeed performs its own DDP internally, and requires the program to be started with: + # deepspeed ./program.py + # rather than: + # python -m torch.distributed.launch --nproc_per_node=2 ./program.py + from .integrations import is_deepspeed_available + + if not is_deepspeed_available(): + raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.") + import deepspeed + + deepspeed.init_distributed() + device = torch.device("cuda", self.local_rank) + self._n_gpu = 1 elif self.local_rank == -1: # if n_gpu is > 1 we'll use nn.DataParallel. # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` @@ -549,21 +563,7 @@ class TrainingArguments: else: # Here, we'll use torch.distributed. # Initializes the distributed backend which will take care of synchronizing nodes/GPUs - # - # deepspeed performs its own DDP internally, and requires the program to be started with: - # deepspeed ./program.py - # rather than: - # python -m torch.distributed.launch --nproc_per_node=2 ./program.py - if self.deepspeed: - from .integrations import is_deepspeed_available - - if not is_deepspeed_available(): - raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.") - import deepspeed - - deepspeed.init_distributed() - else: - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl") device = torch.device("cuda", self.local_rank) self._n_gpu = 1