Add DistributedSamplerWithLoop (#10746)
* Add DistributedSamplerWithLoop * Fix typo * Test and small fix
This commit is contained in:
@@ -26,6 +26,7 @@ from ..modeling_utils import PreTrainedModel, unwrap_model
|
||||
from ..trainer import Trainer
|
||||
from ..trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
SequentialDistributedSampler,
|
||||
nested_detach,
|
||||
nested_numpify,
|
||||
@@ -97,6 +98,13 @@ class SageMakerTrainer(Trainer):
|
||||
return DistributedLengthGroupedSampler(
|
||||
self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
|
||||
)
|
||||
elif not self.args.dataloader_drop_last:
|
||||
return DistributedSamplerWithLoop(
|
||||
self.train_dataset,
|
||||
self.args.per_device_train_batch_size,
|
||||
num_replicas=smp.dp_size(),
|
||||
rank=smp.dp_rank(),
|
||||
)
|
||||
else:
|
||||
return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
|
||||
else:
|
||||
|
||||
@@ -77,6 +77,7 @@ from .trainer_callback import (
|
||||
)
|
||||
from .trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
DistributedTensorGatherer,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
@@ -491,24 +492,10 @@ class Trainer:
|
||||
):
|
||||
return None
|
||||
|
||||
# Gather the number of processes and this process index.
|
||||
if self.args.parallel_mode == ParallelMode.TPU:
|
||||
num_processes = xm.xrt_world_size()
|
||||
process_index = xm.get_ordinal()
|
||||
elif (
|
||||
self.args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
|
||||
):
|
||||
num_processes = dist.get_world_size()
|
||||
process_index = dist.get_rank()
|
||||
else:
|
||||
num_processes = 1
|
||||
process_index = 0
|
||||
|
||||
# Build the sampler.
|
||||
if self.args.group_by_length:
|
||||
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
||||
if num_processes <= 1:
|
||||
if self.args.world_size <= 1:
|
||||
return LengthGroupedSampler(
|
||||
self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
|
||||
)
|
||||
@@ -516,16 +503,26 @@ class Trainer:
|
||||
return DistributedLengthGroupedSampler(
|
||||
self.train_dataset,
|
||||
self.args.train_batch_size,
|
||||
num_replicas=num_processes,
|
||||
rank=process_index,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
model_input_name=model_input_name,
|
||||
)
|
||||
|
||||
else:
|
||||
if num_processes <= 1:
|
||||
if self.args.world_size <= 1:
|
||||
return RandomSampler(self.train_dataset)
|
||||
elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
|
||||
# Use a loop for TPUs when drop_last is False to have all batches have the same size.
|
||||
return DistributedSamplerWithLoop(
|
||||
self.train_dataset,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
else:
|
||||
return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index)
|
||||
return DistributedSampler(
|
||||
self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
|
||||
)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
"""
|
||||
|
||||
@@ -182,6 +182,34 @@ def torch_distributed_zero_first(local_rank: int):
|
||||
dist.barrier()
|
||||
|
||||
|
||||
class DistributedSamplerWithLoop(DistributedSampler):
    """
    Like a :obj:`torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the
    shuffled samples to make each process have a round multiple of batch_size samples.

    Args:
        dataset (:obj:`torch.utils.data.Dataset`):
            Dataset used for sampling.
        batch_size (:obj:`int`):
            The batch size used with this sampler.
        kwargs:
            All other keyword arguments passed to :obj:`DistributedSampler`.
    """

    def __init__(self, dataset, batch_size, **kwargs):
        super().__init__(dataset, **kwargs)
        self.batch_size = batch_size

    def __iter__(self):
        """
        Return an iterator over this rank's indices, extended with looped-back indices so that the number of
        yielded indices is a multiple of ``batch_size``.
        """
        indices = list(super().__iter__())
        num_indices = len(indices)
        # Number of extra samples needed to complete the last batch (0 when it is already full or empty).
        remainder = -num_indices % self.batch_size
        if remainder == 0:
            return iter(indices)

        # DistributedSampler already re-used the first `padding` shuffled samples to make the total a round
        # multiple of the world size. Those samples are the first element of ranks 0..padding-1, so these ranks
        # start looping from their second element to avoid using the same sample a third time.
        if getattr(self, "drop_last", False):
            # With drop_last the parent truncates instead of padding: nothing has been re-used.
            padding = 0
        else:
            padding = -len(self.dataset) % self.num_replicas
        start_remainder = 1 if self.rank < padding else 0

        # Wrap around this rank's own indices so the batch is always completed, even when the shard holds
        # fewer samples than the remainder (a plain slice would silently come up short).
        indices += [indices[(start_remainder + offset) % num_indices] for offset in range(remainder)]
        return iter(indices)
|
||||
|
||||
|
||||
class SequentialDistributedSampler(Sampler):
|
||||
"""
|
||||
Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
|
||||
@@ -228,7 +256,7 @@ class SequentialDistributedSampler(Sampler):
|
||||
return self.num_samples
|
||||
|
||||
|
||||
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
|
||||
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
|
||||
if xm.xrt_world_size() <= 1:
|
||||
return RandomSampler(dataset)
|
||||
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||
|
||||
@@ -742,6 +742,20 @@ class TrainingArguments:
|
||||
return torch.distributed.get_world_size()
|
||||
return 1
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def process_index(self):
|
||||
"""
|
||||
The index of the current process used.
|
||||
"""
|
||||
if is_torch_tpu_available():
|
||||
return xm.get_ordinal()
|
||||
elif is_sagemaker_distributed_available():
|
||||
return sm_dist.get_rank()
|
||||
elif self.local_rank != -1:
|
||||
return torch.distributed.get_rank()
|
||||
return 0
|
||||
|
||||
@property
|
||||
def place_model_on_device(self):
|
||||
"""
|
||||
|
||||
@@ -27,6 +27,7 @@ if is_torch_available():
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
from transformers.trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedSamplerWithLoop,
|
||||
DistributedTensorGatherer,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
@@ -141,3 +142,28 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_distributed_sampler_with_loop(self):
    """Both shards must hold whole batches, cover the dataset, and loop back to the first samples."""
    batch_size = 16
    for length in (23, 64, 123):
        dataset = list(range(length))
        shard1, shard2 = (
            DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=rank) for rank in (0, 1)
        )

        # Same epoch on both shards so they share the same shuffling seed.
        for shard in (shard1, shard2):
            shard.set_epoch(0)

        # Draw every index from each shard.
        samples1 = list(shard1)
        samples2 = list(shard2)

        for samples in (samples1, samples2):
            self.assertTrue(len(samples) % batch_size == 0)

        # Interleave the two shards to recover the order of the shared shuffled sequence.
        total = [index for pair in zip(samples1, samples2) for index in pair]
        extra = len(total) - length

        self.assertEqual(set(total[:length]), set(dataset))
        self.assertEqual(set(total[length:]), set(total[:extra]))
|
||||
|
||||
Reference in New Issue
Block a user