Smdistributed trainer (#9798)
* Add a debug print * Adapt Trainer to use smdistributed if available * Forgotten parenthesis * Real check for sagemaker * Donforget to define device... * Woopsie, local)rank is defined differently * Update since local_rank has the proper value * Remove debug statement * More robust check for smdistributed * Quality * Deal with key not present error
This commit is contained in:
@@ -28,10 +28,16 @@ from torch.utils.data.dataset import Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||
|
||||
from .file_utils import is_torch_tpu_available
|
||||
from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available
|
||||
from .utils import logging
|
||||
|
||||
|
||||
if is_sagemaker_distributed_available():
|
||||
import smdistributed.dataparallel.torch.distributed as dist
|
||||
else:
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
if is_torch_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
@@ -121,8 +127,8 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
|
||||
try:
|
||||
if isinstance(tensor, (tuple, list)):
|
||||
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
|
||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensor)
|
||||
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(output_tensors, tensor)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
@@ -138,8 +144,8 @@ def distributed_broadcast_scalars(
|
||||
) -> torch.Tensor:
|
||||
try:
|
||||
tensorized_scalar = torch.tensor(scalars).cuda()
|
||||
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensorized_scalar)
|
||||
output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(output_tensors, tensorized_scalar)
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
@@ -167,10 +173,10 @@ def torch_distributed_zero_first(local_rank: int):
|
||||
local_rank (:obj:`int`): The rank of the local process.
|
||||
"""
|
||||
if local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier()
|
||||
dist.barrier()
|
||||
yield
|
||||
if local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class SequentialDistributedSampler(Sampler):
|
||||
@@ -185,13 +191,13 @@ class SequentialDistributedSampler(Sampler):
|
||||
|
||||
def __init__(self, dataset, num_replicas=None, rank=None):
|
||||
if num_replicas is None:
|
||||
if not torch.distributed.is_available():
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = torch.distributed.get_world_size()
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not torch.distributed.is_available():
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = torch.distributed.get_rank()
|
||||
rank = dist.get_rank()
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
@@ -480,13 +486,13 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
||||
lengths: Optional[List[int]] = None,
|
||||
):
|
||||
if num_replicas is None:
|
||||
if not torch.distributed.is_available():
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = torch.distributed.get_world_size()
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not torch.distributed.is_available():
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = torch.distributed.get_rank()
|
||||
rank = dist.get_rank()
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.num_replicas = num_replicas
|
||||
|
||||
Reference in New Issue
Block a user