Smdistributed trainer (#9798)
* Add a debug print * Adapt Trainer to use smdistributed if available * Forgotten parenthesis * Real check for sagemaker * Donforget to define device... * Woopsie, local)rank is defined differently * Update since local_rank has the proper value * Remove debug statement * More robust check for smdistributed * Quality * Deal with key not present error
This commit is contained in:
@@ -297,6 +297,20 @@ def is_pandas_available():
|
|||||||
return importlib.util.find_spec("pandas") is not None
|
return importlib.util.find_spec("pandas") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_sagemaker_distributed_available():
|
||||||
|
# Get the sagemaker specific env variable.
|
||||||
|
sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
|
||||||
|
try:
|
||||||
|
# Parse it and check the field "sagemaker_distributed_dataparallel_enabled".
|
||||||
|
sagemaker_params = json.loads(sagemaker_params)
|
||||||
|
if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
|
||||||
|
return False
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return False
|
||||||
|
# Lastly, check if the `smdistributed` module is present.
|
||||||
|
return importlib.util.find_spec("smdistributed") is not None
|
||||||
|
|
||||||
|
|
||||||
def torch_only_method(fn):
|
def torch_only_method(fn):
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
if not _torch_available:
|
if not _torch_available:
|
||||||
|
|||||||
@@ -51,7 +51,14 @@ from torch.utils.data.distributed import DistributedSampler
|
|||||||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||||
|
|
||||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||||
from .file_utils import WEIGHTS_NAME, is_apex_available, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
from .file_utils import (
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
is_apex_available,
|
||||||
|
is_datasets_available,
|
||||||
|
is_in_notebook,
|
||||||
|
is_sagemaker_distributed_available,
|
||||||
|
is_torch_tpu_available,
|
||||||
|
)
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||||
from .optimization import Adafactor, AdamW, get_scheduler
|
from .optimization import Adafactor, AdamW, get_scheduler
|
||||||
@@ -125,6 +132,11 @@ if is_fairscale_available():
|
|||||||
from fairscale.optim import OSS
|
from fairscale.optim import OSS
|
||||||
from fairscale.optim.grad_scaler import ShardedGradScaler
|
from fairscale.optim.grad_scaler import ShardedGradScaler
|
||||||
|
|
||||||
|
if is_sagemaker_distributed_available():
|
||||||
|
import smdistributed.dataparallel.torch.distributed as dist
|
||||||
|
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
|
||||||
|
else:
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import optuna
|
import optuna
|
||||||
@@ -428,9 +440,12 @@ class Trainer:
|
|||||||
if self.args.parallel_mode == ParallelMode.TPU:
|
if self.args.parallel_mode == ParallelMode.TPU:
|
||||||
num_processes = xm.xrt_world_size()
|
num_processes = xm.xrt_world_size()
|
||||||
process_index = xm.get_ordinal()
|
process_index = xm.get_ordinal()
|
||||||
elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
|
elif (
|
||||||
num_processes = torch.distributed.get_world_size()
|
self.args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||||
process_index = torch.distributed.get_rank()
|
or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
|
||||||
|
):
|
||||||
|
num_processes = dist.get_world_size()
|
||||||
|
process_index = dist.get_rank()
|
||||||
else:
|
else:
|
||||||
num_processes = 1
|
num_processes = 1
|
||||||
process_index = 0
|
process_index = 0
|
||||||
@@ -743,6 +758,8 @@ class Trainer:
|
|||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if self.sharded_dpp:
|
if self.sharded_dpp:
|
||||||
model = ShardedDDP(model, self.optimizer)
|
model = ShardedDDP(model, self.optimizer)
|
||||||
|
elif is_sagemaker_distributed_available():
|
||||||
|
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
|
||||||
elif self.args.local_rank != -1:
|
elif self.args.local_rank != -1:
|
||||||
model = torch.nn.parallel.DistributedDataParallel(
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
model,
|
model,
|
||||||
@@ -767,14 +784,13 @@ class Trainer:
|
|||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
|
world_size = xm.xrt_world_size()
|
||||||
|
elif self.args.local_rank != -1:
|
||||||
|
world_size = dist.get_world_size()
|
||||||
else:
|
else:
|
||||||
total_train_batch_size = (
|
world_size = 1
|
||||||
self.args.train_batch_size
|
|
||||||
* self.args.gradient_accumulation_steps
|
|
||||||
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
|
||||||
num_examples = (
|
num_examples = (
|
||||||
self.num_examples(train_dataloader)
|
self.num_examples(train_dataloader)
|
||||||
if train_dataset_is_sized
|
if train_dataset_is_sized
|
||||||
@@ -1302,7 +1318,7 @@ class Trainer:
|
|||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
return xm.is_master_ordinal(local=False)
|
return xm.is_master_ordinal(local=False)
|
||||||
else:
|
else:
|
||||||
return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
|
return self.args.local_rank == -1 or dist.get_rank() == 0
|
||||||
|
|
||||||
def save_model(self, output_dir: Optional[str] = None):
|
def save_model(self, output_dir: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
@@ -1542,7 +1558,7 @@ class Trainer:
|
|||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
world_size = xm.xrt_world_size()
|
world_size = xm.xrt_world_size()
|
||||||
elif self.args.local_rank != -1:
|
elif self.args.local_rank != -1:
|
||||||
world_size = torch.distributed.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
world_size = max(1, world_size)
|
world_size = max(1, world_size)
|
||||||
|
|
||||||
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
|
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
|
||||||
|
|||||||
@@ -28,10 +28,16 @@ from torch.utils.data.dataset import Dataset
|
|||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||||
|
|
||||||
from .file_utils import is_torch_tpu_available
|
from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_sagemaker_distributed_available():
|
||||||
|
import smdistributed.dataparallel.torch.distributed as dist
|
||||||
|
else:
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
@@ -121,8 +127,8 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
|
|||||||
try:
|
try:
|
||||||
if isinstance(tensor, (tuple, list)):
|
if isinstance(tensor, (tuple, list)):
|
||||||
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
||||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
|
||||||
torch.distributed.all_gather(output_tensors, tensor)
|
dist.all_gather(output_tensors, tensor)
|
||||||
concat = torch.cat(output_tensors, dim=0)
|
concat = torch.cat(output_tensors, dim=0)
|
||||||
|
|
||||||
# truncate the dummy elements added by SequentialDistributedSampler
|
# truncate the dummy elements added by SequentialDistributedSampler
|
||||||
@@ -138,8 +144,8 @@ def distributed_broadcast_scalars(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
try:
|
try:
|
||||||
tensorized_scalar = torch.tensor(scalars).cuda()
|
tensorized_scalar = torch.tensor(scalars).cuda()
|
||||||
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
|
||||||
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
dist.all_gather(output_tensors, tensorized_scalar)
|
||||||
concat = torch.cat(output_tensors, dim=0)
|
concat = torch.cat(output_tensors, dim=0)
|
||||||
|
|
||||||
# truncate the dummy elements added by SequentialDistributedSampler
|
# truncate the dummy elements added by SequentialDistributedSampler
|
||||||
@@ -167,10 +173,10 @@ def torch_distributed_zero_first(local_rank: int):
|
|||||||
local_rank (:obj:`int`): The rank of the local process.
|
local_rank (:obj:`int`): The rank of the local process.
|
||||||
"""
|
"""
|
||||||
if local_rank not in [-1, 0]:
|
if local_rank not in [-1, 0]:
|
||||||
torch.distributed.barrier()
|
dist.barrier()
|
||||||
yield
|
yield
|
||||||
if local_rank == 0:
|
if local_rank == 0:
|
||||||
torch.distributed.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
class SequentialDistributedSampler(Sampler):
|
class SequentialDistributedSampler(Sampler):
|
||||||
@@ -185,13 +191,13 @@ class SequentialDistributedSampler(Sampler):
|
|||||||
|
|
||||||
def __init__(self, dataset, num_replicas=None, rank=None):
|
def __init__(self, dataset, num_replicas=None, rank=None):
|
||||||
if num_replicas is None:
|
if num_replicas is None:
|
||||||
if not torch.distributed.is_available():
|
if not dist.is_available():
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
num_replicas = torch.distributed.get_world_size()
|
num_replicas = dist.get_world_size()
|
||||||
if rank is None:
|
if rank is None:
|
||||||
if not torch.distributed.is_available():
|
if not dist.is_available():
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
rank = torch.distributed.get_rank()
|
rank = dist.get_rank()
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.num_replicas = num_replicas
|
self.num_replicas = num_replicas
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
@@ -480,13 +486,13 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
|||||||
lengths: Optional[List[int]] = None,
|
lengths: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
if num_replicas is None:
|
if num_replicas is None:
|
||||||
if not torch.distributed.is_available():
|
if not dist.is_available():
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
num_replicas = torch.distributed.get_world_size()
|
num_replicas = dist.get_world_size()
|
||||||
if rank is None:
|
if rank is None:
|
||||||
if not torch.distributed.is_available():
|
if not dist.is_available():
|
||||||
raise RuntimeError("Requires distributed package to be available")
|
raise RuntimeError("Requires distributed package to be available")
|
||||||
rank = torch.distributed.get_rank()
|
rank = dist.get_rank()
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_replicas = num_replicas
|
self.num_replicas = num_replicas
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
|
from .file_utils import is_sagemaker_distributed_available, is_tf_available, is_torch_available, is_torch_tpu_available
|
||||||
from .tokenization_utils_base import ExplicitEnum
|
from .tokenization_utils_base import ExplicitEnum
|
||||||
|
|
||||||
|
|
||||||
@@ -187,6 +187,10 @@ def total_processes_number(local_rank):
|
|||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
return xm.xrt_world_size()
|
return xm.xrt_world_size()
|
||||||
|
elif is_sagemaker_distributed_available():
|
||||||
|
import smdistributed.dataparallel.torch.distributed as dist
|
||||||
|
|
||||||
|
return dist.get_world_size()
|
||||||
elif local_rank != -1 and is_torch_available():
|
elif local_rank != -1 and is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,13 @@ from dataclasses import asdict, dataclass, field
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
from .file_utils import (
|
||||||
|
cached_property,
|
||||||
|
is_sagemaker_distributed_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_torch_tpu_available,
|
||||||
|
torch_required,
|
||||||
|
)
|
||||||
from .trainer_utils import EvaluationStrategy, SchedulerType
|
from .trainer_utils import EvaluationStrategy, SchedulerType
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
@@ -493,6 +499,13 @@ class TrainingArguments:
|
|||||||
elif is_torch_tpu_available():
|
elif is_torch_tpu_available():
|
||||||
device = xm.xla_device()
|
device = xm.xla_device()
|
||||||
self._n_gpu = 0
|
self._n_gpu = 0
|
||||||
|
elif is_sagemaker_distributed_available():
|
||||||
|
import smdistributed.dataparallel.torch.distributed as dist
|
||||||
|
|
||||||
|
dist.init_process_group()
|
||||||
|
self.local_rank = dist.get_local_rank()
|
||||||
|
device = torch.device("cuda", self.local_rank)
|
||||||
|
self._n_gpu = 1
|
||||||
elif self.local_rank == -1:
|
elif self.local_rank == -1:
|
||||||
# if n_gpu is > 1 we'll use nn.DataParallel.
|
# if n_gpu is > 1 we'll use nn.DataParallel.
|
||||||
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
||||||
@@ -566,6 +579,8 @@ class TrainingArguments:
|
|||||||
"""
|
"""
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
return ParallelMode.TPU
|
return ParallelMode.TPU
|
||||||
|
elif is_sagemaker_distributed_available():
|
||||||
|
return ParallelMode.SAGEMAKER_DISTRIBUTED
|
||||||
elif self.local_rank != -1:
|
elif self.local_rank != -1:
|
||||||
return ParallelMode.DISTRIBUTED
|
return ParallelMode.DISTRIBUTED
|
||||||
elif self.n_gpu > 1:
|
elif self.n_gpu > 1:
|
||||||
@@ -607,4 +622,5 @@ class ParallelMode(Enum):
|
|||||||
NOT_PARALLEL = "not_parallel"
|
NOT_PARALLEL = "not_parallel"
|
||||||
NOT_DISTRIBUTED = "not_distributed"
|
NOT_DISTRIBUTED = "not_distributed"
|
||||||
DISTRIBUTED = "distributed"
|
DISTRIBUTED = "distributed"
|
||||||
|
SAGEMAKER_DISTRIBUTED = "sm_distributed"
|
||||||
TPU = "tpu"
|
TPU = "tpu"
|
||||||
|
|||||||
Reference in New Issue
Block a user