Add DistributedSamplerWithLoop (#10746)

* Add DistributedSamplerWithLoop

* Fix typo

* Test and small fix
This commit is contained in:
Sylvain Gugger
2021-03-16 11:22:39 -04:00
committed by GitHub
parent 1449222217
commit a0a027c2ed
5 changed files with 93 additions and 20 deletions

View File

@@ -26,6 +26,7 @@ from ..modeling_utils import PreTrainedModel, unwrap_model
from ..trainer import Trainer from ..trainer import Trainer
from ..trainer_pt_utils import ( from ..trainer_pt_utils import (
DistributedLengthGroupedSampler, DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
SequentialDistributedSampler, SequentialDistributedSampler,
nested_detach, nested_detach,
nested_numpify, nested_numpify,
@@ -97,6 +98,13 @@ class SageMakerTrainer(Trainer):
return DistributedLengthGroupedSampler( return DistributedLengthGroupedSampler(
self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank() self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
) )
elif not self.args.dataloader_drop_last:
return DistributedSamplerWithLoop(
self.train_dataset,
self.args.per_device_train_batch_size,
num_replicas=smp.dp_size(),
rank=smp.dp_rank(),
)
else: else:
return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank()) return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
else: else:

View File

@@ -77,6 +77,7 @@ from .trainer_callback import (
) )
from .trainer_pt_utils import ( from .trainer_pt_utils import (
DistributedLengthGroupedSampler, DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer, DistributedTensorGatherer,
LabelSmoother, LabelSmoother,
LengthGroupedSampler, LengthGroupedSampler,
@@ -491,24 +492,10 @@ class Trainer:
): ):
return None return None
# Gather the number of processes and this process index.
if self.args.parallel_mode == ParallelMode.TPU:
num_processes = xm.xrt_world_size()
process_index = xm.get_ordinal()
elif (
self.args.parallel_mode == ParallelMode.DISTRIBUTED
or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
):
num_processes = dist.get_world_size()
process_index = dist.get_rank()
else:
num_processes = 1
process_index = 0
# Build the sampler. # Build the sampler.
if self.args.group_by_length: if self.args.group_by_length:
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
if num_processes <= 1: if self.args.world_size <= 1:
return LengthGroupedSampler( return LengthGroupedSampler(
self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
) )
@@ -516,16 +503,26 @@ class Trainer:
return DistributedLengthGroupedSampler( return DistributedLengthGroupedSampler(
self.train_dataset, self.train_dataset,
self.args.train_batch_size, self.args.train_batch_size,
num_replicas=num_processes, num_replicas=self.args.world_size,
rank=process_index, rank=self.args.process_index,
model_input_name=model_input_name, model_input_name=model_input_name,
) )
else: else:
if num_processes <= 1: if self.args.world_size <= 1:
return RandomSampler(self.train_dataset) return RandomSampler(self.train_dataset)
elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
# Use a loop for TPUs when drop_last is False to have all batches have the same size.
return DistributedSamplerWithLoop(
self.train_dataset,
batch_size=self.args.per_device_train_batch_size,
num_replicas=self.args.world_size,
rank=self.args.process_index,
)
else: else:
return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index) return DistributedSampler(
self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
)
def get_train_dataloader(self) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
""" """

View File

@@ -182,6 +182,34 @@ def torch_distributed_zero_first(local_rank: int):
dist.barrier() dist.barrier()
class DistributedSamplerWithLoop(DistributedSampler):
    """
    Like a :obj:`torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the
    shuffled samples to make each process have a round multiple of batch_size samples.

    Args:
        dataset (:obj:`torch.utils.data.Dataset`):
            Dataset used for sampling.
        batch_size (:obj:`int`):
            The batch size used with this sampler.
        kwargs:
            All other keyword arguments passed to :obj:`DistributedSampler`.
    """

    def __init__(self, dataset, batch_size, **kwargs):
        super().__init__(dataset, **kwargs)
        self.batch_size = batch_size

    def __iter__(self):
        indices = list(super().__iter__())
        # Number of extra samples needed to reach the next multiple of batch_size (0 when already aligned).
        remainder = -len(indices) % self.batch_size
        # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
        # of the world size, so we skip those.
        start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
        # Samples we loop over to pad. This is a copy, so extending `indices` below does not change it. If skipping
        # leaves nothing (single-sample shard), fall back to the whole shard.
        pool = indices[start_remainder:] or list(indices)
        # A single slice is not enough when the shard is shorter than the remainder, so keep looping over the pool
        # until the shard length is a multiple of batch_size. When one pass suffices this is exactly
        # indices[start_remainder : start_remainder + remainder].
        while remainder > 0 and pool:
            chunk = pool[:remainder]
            indices += chunk
            remainder -= len(chunk)
        return iter(indices)
class SequentialDistributedSampler(Sampler): class SequentialDistributedSampler(Sampler):
""" """
Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end. Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
@@ -228,7 +256,7 @@ class SequentialDistributedSampler(Sampler):
return self.num_samples return self.num_samples
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset): def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
if xm.xrt_world_size() <= 1: if xm.xrt_world_size() <= 1:
return RandomSampler(dataset) return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

View File

@@ -742,6 +742,20 @@ class TrainingArguments:
return torch.distributed.get_world_size() return torch.distributed.get_world_size()
return 1 return 1
@property
@torch_required
def process_index(self):
"""
The index of the current process used.
"""
# Backends are checked in order: TPU ordinal, SageMaker distributed rank, then the torch.distributed rank when
# `local_rank` is set (i.e. different from -1).
if is_torch_tpu_available():
return xm.get_ordinal()
elif is_sagemaker_distributed_available():
return sm_dist.get_rank()
elif self.local_rank != -1:
return torch.distributed.get_rank()
# None of the branches above applied: fall back to index 0.
return 0
@property @property
def place_model_on_device(self): def place_model_on_device(self):
""" """

View File

@@ -27,6 +27,7 @@ if is_torch_available():
from transformers.modeling_outputs import SequenceClassifierOutput from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_pt_utils import ( from transformers.trainer_pt_utils import (
DistributedLengthGroupedSampler, DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer, DistributedTensorGatherer,
LabelSmoother, LabelSmoother,
LengthGroupedSampler, LengthGroupedSampler,
@@ -141,3 +142,28 @@ class TrainerUtilsTest(unittest.TestCase):
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias'] ['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
) )
# fmt: on # fmt: on
def test_distributed_sampler_with_loop(self):
# Checks, for two replicas, that each shard's length is a multiple of the batch size, that the shards together
# cover the dataset, and that the extra samples are repeats of the first ones.
batch_size = 16
# Lengths 23 and 123 are not multiples of the batch size; 64 is.
for length in [23, 64, 123]:
dataset = list(range(length))
shard1 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=0)
shard2 = DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=1)
# Set seeds
shard1.set_epoch(0)
shard2.set_epoch(0)
# Sample
samples1 = list(shard1)
samples2 = list(shard2)
# Every shard must contain a whole number of batches.
self.assertTrue(len(samples1) % batch_size == 0)
self.assertTrue(len(samples2) % batch_size == 0)
# Interleave the two shards' samples, rank 0 first.
total = []
for sample1, sample2 in zip(samples1, samples2):
total += [sample1, sample2]
# The first `length` interleaved samples must cover the whole dataset.
self.assertEqual(set(total[:length]), set(dataset))
# Everything past `length` must be the same set as the first `len(total) - length` interleaved samples.
self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))