Smdistributed trainer (#9798)
* Add a debug print * Adapt Trainer to use smdistributed if available * Forgotten parenthesis * Real check for sagemaker * Don't forget to define device... * Woopsie, local_rank is defined differently * Update since local_rank has the proper value * Remove debug statement * More robust check for smdistributed * Quality * Deal with key not present error
This commit is contained in:
@@ -18,7 +18,13 @@ from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
||||
from .file_utils import (
|
||||
cached_property,
|
||||
is_sagemaker_distributed_available,
|
||||
is_torch_available,
|
||||
is_torch_tpu_available,
|
||||
torch_required,
|
||||
)
|
||||
from .trainer_utils import EvaluationStrategy, SchedulerType
|
||||
from .utils import logging
|
||||
|
||||
@@ -493,6 +499,13 @@ class TrainingArguments:
|
||||
elif is_torch_tpu_available():
|
||||
device = xm.xla_device()
|
||||
self._n_gpu = 0
|
||||
elif is_sagemaker_distributed_available():
|
||||
import smdistributed.dataparallel.torch.distributed as dist
|
||||
|
||||
dist.init_process_group()
|
||||
self.local_rank = dist.get_local_rank()
|
||||
device = torch.device("cuda", self.local_rank)
|
||||
self._n_gpu = 1
|
||||
elif self.local_rank == -1:
|
||||
# if n_gpu is > 1 we'll use nn.DataParallel.
|
||||
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
||||
@@ -566,6 +579,8 @@ class TrainingArguments:
|
||||
"""
|
||||
if is_torch_tpu_available():
|
||||
return ParallelMode.TPU
|
||||
elif is_sagemaker_distributed_available():
|
||||
return ParallelMode.SAGEMAKER_DISTRIBUTED
|
||||
elif self.local_rank != -1:
|
||||
return ParallelMode.DISTRIBUTED
|
||||
elif self.n_gpu > 1:
|
||||
@@ -607,4 +622,5 @@ class ParallelMode(Enum):
|
||||
NOT_PARALLEL = "not_parallel"
|
||||
NOT_DISTRIBUTED = "not_distributed"
|
||||
DISTRIBUTED = "distributed"
|
||||
SAGEMAKER_DISTRIBUTED = "sm_distributed"
|
||||
TPU = "tpu"
|
||||
|
||||
Reference in New Issue
Block a user