Smdistributed trainer (#9798)

* Add a debug print

* Adapt Trainer to use smdistributed if available

* Forgotten parenthesis

* Real check for sagemaker

* Donforget to define device...

* Woopsie, local)rank is defined differently

* Update since local_rank has the proper value

* Remove debug statement

* More robust check for smdistributed

* Quality

* Deal with key not present error
This commit is contained in:
Sylvain Gugger
2021-01-26 10:28:21 -05:00
committed by GitHub
parent 897a24c869
commit 0d0efd3a0e
5 changed files with 85 additions and 29 deletions

View File

@@ -28,10 +28,16 @@ from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler
from .file_utils import is_torch_tpu_available
from .file_utils import is_sagemaker_distributed_available, is_torch_tpu_available
from .utils import logging
if is_sagemaker_distributed_available():
import smdistributed.dataparallel.torch.distributed as dist
else:
import torch.distributed as dist
if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
@@ -121,8 +127,8 @@ def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int]
try:
if isinstance(tensor, (tuple, list)):
return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
@@ -138,8 +144,8 @@ def distributed_broadcast_scalars(
) -> torch.Tensor:
try:
tensorized_scalar = torch.tensor(scalars).cuda()
output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensorized_scalar)
output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
dist.all_gather(output_tensors, tensorized_scalar)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
@@ -167,10 +173,10 @@ def torch_distributed_zero_first(local_rank: int):
local_rank (:obj:`int`): The rank of the local process.
"""
if local_rank not in [-1, 0]:
torch.distributed.barrier()
dist.barrier()
yield
if local_rank == 0:
torch.distributed.barrier()
dist.barrier()
class SequentialDistributedSampler(Sampler):
@@ -185,13 +191,13 @@ class SequentialDistributedSampler(Sampler):
def __init__(self, dataset, num_replicas=None, rank=None):
if num_replicas is None:
if not torch.distributed.is_available():
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
num_replicas = dist.get_world_size()
if rank is None:
if not torch.distributed.is_available():
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@@ -480,13 +486,13 @@ class DistributedLengthGroupedSampler(DistributedSampler):
lengths: Optional[List[int]] = None,
):
if num_replicas is None:
if not torch.distributed.is_available():
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
num_replicas = dist.get_world_size()
if rank is None:
if not torch.distributed.is_available():
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
rank = dist.get_rank()
self.dataset = dataset
self.batch_size = batch_size
self.num_replicas = num_replicas