Add DistributedSamplerWithLoop (#10746)
* Add DistributedSamplerWithLoop * Fix typo * Test and small fix
This commit is contained in:
@@ -26,6 +26,7 @@ from ..modeling_utils import PreTrainedModel, unwrap_model
|
|||||||
from ..trainer import Trainer
|
from ..trainer import Trainer
|
||||||
from ..trainer_pt_utils import (
|
from ..trainer_pt_utils import (
|
||||||
DistributedLengthGroupedSampler,
|
DistributedLengthGroupedSampler,
|
||||||
|
DistributedSamplerWithLoop,
|
||||||
SequentialDistributedSampler,
|
SequentialDistributedSampler,
|
||||||
nested_detach,
|
nested_detach,
|
||||||
nested_numpify,
|
nested_numpify,
|
||||||
@@ -97,6 +98,13 @@ class SageMakerTrainer(Trainer):
|
|||||||
return DistributedLengthGroupedSampler(
|
return DistributedLengthGroupedSampler(
|
||||||
self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
|
self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
|
||||||
)
|
)
|
||||||
|
elif not self.args.dataloader_drop_last:
|
||||||
|
return DistributedSamplerWithLoop(
|
||||||
|
self.train_dataset,
|
||||||
|
self.args.per_device_train_batch_size,
|
||||||
|
num_replicas=smp.dp_size(),
|
||||||
|
rank=smp.dp_rank(),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
|
return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ from .trainer_callback import (
|
|||||||
)
|
)
|
||||||
from .trainer_pt_utils import (
|
from .trainer_pt_utils import (
|
||||||
DistributedLengthGroupedSampler,
|
DistributedLengthGroupedSampler,
|
||||||
|
DistributedSamplerWithLoop,
|
||||||
DistributedTensorGatherer,
|
DistributedTensorGatherer,
|
||||||
LabelSmoother,
|
LabelSmoother,
|
||||||
LengthGroupedSampler,
|
LengthGroupedSampler,
|
||||||
@@ -491,24 +492,10 @@ class Trainer:
|
|||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Gather the number of processes and this process index.
|
|
||||||
if self.args.parallel_mode == ParallelMode.TPU:
|
|
||||||
num_processes = xm.xrt_world_size()
|
|
||||||
process_index = xm.get_ordinal()
|
|
||||||
elif (
|
|
||||||
self.args.parallel_mode == ParallelMode.DISTRIBUTED
|
|
||||||
or self.args.parallel_mode == ParallelMode.SAGEMAKER_DISTRIBUTED
|
|
||||||
):
|
|
||||||
num_processes = dist.get_world_size()
|
|
||||||
process_index = dist.get_rank()
|
|
||||||
else:
|
|
||||||
num_processes = 1
|
|
||||||
process_index = 0
|
|
||||||
|
|
||||||
# Build the sampler.
|
# Build the sampler.
|
||||||
if self.args.group_by_length:
|
if self.args.group_by_length:
|
||||||
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
||||||
if num_processes <= 1:
|
if self.args.world_size <= 1:
|
||||||
return LengthGroupedSampler(
|
return LengthGroupedSampler(
|
||||||
self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
|
self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
|
||||||
)
|
)
|
||||||
@@ -516,16 +503,26 @@ class Trainer:
|
|||||||
return DistributedLengthGroupedSampler(
|
return DistributedLengthGroupedSampler(
|
||||||
self.train_dataset,
|
self.train_dataset,
|
||||||
self.args.train_batch_size,
|
self.args.train_batch_size,
|
||||||
num_replicas=num_processes,
|
num_replicas=self.args.world_size,
|
||||||
rank=process_index,
|
rank=self.args.process_index,
|
||||||
model_input_name=model_input_name,
|
model_input_name=model_input_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if num_processes <= 1:
|
if self.args.world_size <= 1:
|
||||||
return RandomSampler(self.train_dataset)
|
return RandomSampler(self.train_dataset)
|
||||||
|
elif self.args.parallel_mode == ParallelMode.TPU and not self.args.dataloader_drop_last:
|
||||||
|
# Use a loop for TPUs when drop_last is False to have all batches have the same size.
|
||||||
|
return DistributedSamplerWithLoop(
|
||||||
|
self.train_dataset,
|
||||||
|
batch_size=self.args.per_device_train_batch_size,
|
||||||
|
num_replicas=self.args.world_size,
|
||||||
|
rank=self.args.process_index,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return DistributedSampler(self.train_dataset, num_replicas=num_processes, rank=process_index)
|
return DistributedSampler(
|
||||||
|
self.train_dataset, num_replicas=self.args.world_size, rank=self.args.process_index
|
||||||
|
)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -182,6 +182,34 @@ def torch_distributed_zero_first(local_rank: int):
|
|||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedSamplerWithLoop(DistributedSampler):
    """
    Like a :obj:`torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the
    shuffled samples to make each process have a round multiple of batch_size samples.

    Args:
        dataset (:obj:`torch.utils.data.Dataset`):
            Dataset used for sampling.
        batch_size (:obj:`int`):
            The batch size used with this sampler.
        kwargs:
            All other keyword arguments passed to :obj:`DistributedSampler`.
    """

    def __init__(self, dataset, batch_size, **kwargs):
        super().__init__(dataset, **kwargs)
        self.batch_size = batch_size

    def __iter__(self):
        """
        Iterate over this process' shard, extended with looped samples so its length is a multiple of
        :obj:`batch_size`.
        """
        indices = list(super().__iter__())
        # Number of extra samples needed to reach the next multiple of batch_size (0 if already aligned).
        remainder = -len(indices) % self.batch_size
        if remainder == 0:
            return iter(indices)

        # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
        # of the world size, so we skip those.
        start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0

        # Loop over the shard starting at `start_remainder`, wrapping around as many times as necessary. A plain slice
        # would come up short when the shard holds fewer than `remainder` samples, leaving the last batch incomplete.
        looped = indices[start_remainder:] + indices[:start_remainder]
        repeats = -(-remainder // len(looped))  # ceiling division
        indices += (looped * repeats)[:remainder]
        return iter(indices)
|
||||||
|
|
||||||
|
|
||||||
class SequentialDistributedSampler(Sampler):
|
class SequentialDistributedSampler(Sampler):
|
||||||
"""
|
"""
|
||||||
Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
|
Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
|
||||||
@@ -228,7 +256,7 @@ class SequentialDistributedSampler(Sampler):
|
|||||||
return self.num_samples
|
return self.num_samples
|
||||||
|
|
||||||
|
|
||||||
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
|
def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
|
||||||
if xm.xrt_world_size() <= 1:
|
if xm.xrt_world_size() <= 1:
|
||||||
return RandomSampler(dataset)
|
return RandomSampler(dataset)
|
||||||
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||||
|
|||||||
@@ -742,6 +742,20 @@ class TrainingArguments:
|
|||||||
return torch.distributed.get_world_size()
|
return torch.distributed.get_world_size()
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
@property
@torch_required
def process_index(self):
    """
    The index of the current process among the processes used in parallel.
    """
    # Backends are checked in priority order: TPU first, then SageMaker distributed, then torch.distributed.
    if is_torch_tpu_available():
        return xm.get_ordinal()
    if is_sagemaker_distributed_available():
        return sm_dist.get_rank()
    if self.local_rank != -1:
        return torch.distributed.get_rank()
    # None of the distributed checks matched: report index 0.
    return 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def place_model_on_device(self):
|
def place_model_on_device(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ if is_torch_available():
|
|||||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||||
from transformers.trainer_pt_utils import (
|
from transformers.trainer_pt_utils import (
|
||||||
DistributedLengthGroupedSampler,
|
DistributedLengthGroupedSampler,
|
||||||
|
DistributedSamplerWithLoop,
|
||||||
DistributedTensorGatherer,
|
DistributedTensorGatherer,
|
||||||
LabelSmoother,
|
LabelSmoother,
|
||||||
LengthGroupedSampler,
|
LengthGroupedSampler,
|
||||||
@@ -141,3 +142,28 @@ class TrainerUtilsTest(unittest.TestCase):
|
|||||||
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
|
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
def test_distributed_sampler_with_loop(self):
    """
    Check that two shards of `DistributedSamplerWithLoop` each yield a multiple of the batch size, cover the whole
    dataset between them, and pad only with samples taken from the beginning of the interleaved order.
    """
    batch_size = 16
    for length in [23, 64, 123]:
        dataset = list(range(length))
        shards = [
            DistributedSamplerWithLoop(dataset, batch_size, num_replicas=2, rank=rank) for rank in range(2)
        ]

        # Set seeds
        for shard in shards:
            shard.set_epoch(0)

        # Sample
        samples1, samples2 = [list(shard) for shard in shards]

        self.assertTrue(len(samples1) % batch_size == 0)
        self.assertTrue(len(samples2) % batch_size == 0)

        # Interleave the two shards to rebuild the order in which samples were distributed.
        total = [sample for pair in zip(samples1, samples2) for sample in pair]

        self.assertEqual(set(total[:length]), set(dataset))
        # Everything past the dataset length must be a repeat of the first samples.
        self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))
|
||||||
|
|||||||
Reference in New Issue
Block a user