Smdistributed trainer (#9798)

* Add a debug print

* Adapt Trainer to use smdistributed if available

* Forgotten parenthesis

* Real check for sagemaker

* Donforget to define device...

* Woopsie, local)rank is defined differently

* Update since local_rank has the proper value

* Remove debug statement

* More robust check for smdistributed

* Quality

* Deal with key not present error
This commit is contained in:
Sylvain Gugger
2021-01-26 10:28:21 -05:00
committed by GitHub
parent 897a24c869
commit 0d0efd3a0e
5 changed files with 85 additions and 29 deletions

View File

@@ -18,7 +18,13 @@ from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .file_utils import (
cached_property,
is_sagemaker_distributed_available,
is_torch_available,
is_torch_tpu_available,
torch_required,
)
from .trainer_utils import EvaluationStrategy, SchedulerType
from .utils import logging
@@ -493,6 +499,13 @@ class TrainingArguments:
elif is_torch_tpu_available():
device = xm.xla_device()
self._n_gpu = 0
elif is_sagemaker_distributed_available():
import smdistributed.dataparallel.torch.distributed as dist
dist.init_process_group()
self.local_rank = dist.get_local_rank()
device = torch.device("cuda", self.local_rank)
self._n_gpu = 1
elif self.local_rank == -1:
# if n_gpu is > 1 we'll use nn.DataParallel.
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -566,6 +579,8 @@ class TrainingArguments:
"""
if is_torch_tpu_available():
return ParallelMode.TPU
elif is_sagemaker_distributed_available():
return ParallelMode.SAGEMAKER_DISTRIBUTED
elif self.local_rank != -1:
return ParallelMode.DISTRIBUTED
elif self.n_gpu > 1:
@@ -607,4 +622,5 @@ class ParallelMode(Enum):
NOT_PARALLEL = "not_parallel"
NOT_DISTRIBUTED = "not_distributed"
DISTRIBUTED = "distributed"
SAGEMAKER_DISTRIBUTED = "sm_distributed"
TPU = "tpu"