diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 72c36cca65..2069dde739 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -3806,7 +3806,9 @@ class Trainer: # create accelerator object self.accelerator = Accelerator( - deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin + dispatch_batches=self.args.dispatch_batches, + deepspeed_plugin=self.args.deepspeed_plugin, + gradient_accumulation_plugin=gradient_accumulation_plugin, ) # deepspeed and accelerate flags covering both trainer args and accelerate launcher diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 2efafaea77..313caf47e9 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1200,6 +1200,15 @@ class TrainingArguments: }, ) + dispatch_batches: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether to dispatch batches across devices in distributed training. If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" + "and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" + "underlying dataset is an `IterableDataset`, `False` otherwise." + }, + ) + def __post_init__(self): # expand paths, if not os.makedirs("~/bar") will make directory # in the current directory instead of the actual home diff --git a/tests/trainer/test_trainer_distributed.py b/tests/trainer/test_trainer_distributed.py index 8711a3970f..5a7734b8ba 100644 --- a/tests/trainer/test_trainer_distributed.py +++ b/tests/trainer/test_trainer_distributed.py @@ -14,6 +14,8 @@ from typing import Dict +import numpy as np + from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available from transformers.testing_utils import ( TestCasePlus, @@ -33,7 +35,7 @@ logger = logging.get_logger(__name__) if is_torch_available(): import torch from torch import nn - from torch.utils.data import Dataset + from torch.utils.data import Dataset, IterableDataset from transformers import Trainer @@ -63,6 +65,56 @@ if is_torch_available(): else: return input_ids + class RegressionModel(nn.Module): + def __init__(self, a=0, b=0, double_output=False): + super().__init__() + self.a = nn.Parameter(torch.tensor(a).float()) + self.b = nn.Parameter(torch.tensor(b).float()) + self.double_output = double_output + self.config = None + + def forward(self, input_x, labels=None, **kwargs): + y = input_x * self.a + self.b + if labels is None: + return (y, y) if self.double_output else (y,) + loss = nn.functional.mse_loss(y, labels) + return (loss, y, y) if self.double_output else (loss, y) + + class SampleIterableDataset(IterableDataset): + def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): + self.dataset = RegressionDataset(a=a, b=b, length=length, seed=seed, label_names=label_names) + + def __iter__(self): + for i in range(len(self.dataset)): + yield self.dataset[i] + + class FiniteIterableDataset(SampleIterableDataset): + def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): + super().__init__(a, b, length, seed, label_names) + self.current_sample = 0 + + def __iter__(self): + while self.current_sample < len(self.dataset): + yield self.dataset[self.current_sample] + self.current_sample += 1 + + class RegressionDataset: + def __init__(self, a=2, b=3, length=64, seed=42, label_names=None): + np.random.seed(seed) + self.label_names = ["labels"] if label_names is None else label_names + self.length = length + self.x = np.random.normal(size=(length,)).astype(np.float32) + self.ys = [a * self.x + b + np.random.normal(scale=0.1, size=(length,)) for _ in self.label_names] + self.ys = [y.astype(np.float32) for y in self.ys] + + def __len__(self): + return self.length + + def __getitem__(self, i): + result = {name: y[i] for name, y in zip(self.label_names, self.ys)} + result["input_x"] = self.x[i] + return result + class TestTrainerDistributedNeuronCore(TestCasePlus): @require_torch_neuroncore @@ -168,3 +220,14 @@ if __name__ == "__main__": exit(1) trainer.args.eval_accumulation_steps = None + + # Check that `dispatch_batches=False` will work on a finite iterable dataset + + train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1) + + model = RegressionModel() + training_args.per_device_train_batch_size = 1 + training_args.max_steps = 1 + training_args.dispatch_batches = False + trainer = Trainer(model, training_args, train_dataset=train_dataset) + trainer.train()