From a63cb7578e13ca59968d715bd37ecb3d456f3ac6 Mon Sep 17 00:00:00 2001
From: Shiyu
Date: Mon, 12 May 2025 23:59:16 +0800
Subject: [PATCH] update seed_worker to set seed based on worker_id and rank (#37980)

* update seed_worker to set seed based on worker_id and rank

* test case

* set output_dir as remove tmp dir

* wrap the rank-offset worker seed back into the 32-bit range

* fix the return annotation of DummyDataset.__getitem__
---
 src/transformers/trainer.py                   |  5 +-
 src/transformers/trainer_utils.py             |  5 +-
 .../test_trainer_distributed_worker_seed.py   | 89 +++++++++++++++++++
 3 files changed, 96 insertions(+), 3 deletions(-)
 create mode 100644 tests/trainer/test_trainer_distributed_worker_seed.py

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ccbb4ebe44..5886146002 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -32,6 +32,7 @@ import tempfile
 import time
 import warnings
 from collections.abc import Mapping
+from functools import partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
@@ -1028,7 +1029,9 @@ class Trainer:
         if not isinstance(train_dataset, torch.utils.data.IterableDataset):
             dataloader_params["sampler"] = self._get_train_sampler()
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
-            dataloader_params["worker_init_fn"] = seed_worker
+            dataloader_params["worker_init_fn"] = partial(
+                seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
+            )
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 7b2d5c3432..5556c1d43e 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -49,11 +49,12 @@ if is_torch_available():
     import torch
 
 
-def seed_worker(_):
+def seed_worker(worker_id: int, num_workers: int, rank: int):
     """
     Helper function to set worker seed during Dataloader initialization.
     """
-    worker_seed = torch.initial_seed() % 2**32
+    init_seed = torch.initial_seed() % 2**32
+    worker_seed = (num_workers * rank + init_seed) % 2**32  # keep within the 32-bit seed range
     set_seed(worker_seed)
 
 
diff --git a/tests/trainer/test_trainer_distributed_worker_seed.py b/tests/trainer/test_trainer_distributed_worker_seed.py
new file mode 100644
index 0000000000..f4fececf10
--- /dev/null
+++ b/tests/trainer/test_trainer_distributed_worker_seed.py
@@ -0,0 +1,89 @@
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.utils.data import Dataset
+
+from transformers import (
+    HfArgumentParser,
+    Trainer,
+    TrainingArguments,
+    set_seed,
+)
+from transformers.testing_utils import (
+    TestCasePlus,
+    execute_subprocess_async,
+    get_torch_dist_unique_port,
+    require_torch_multi_gpu,
+)
+
+
+def gather_from_all_gpus(tensor, world_size):
+    # Prepare a list to gather tensors from all processes
+    gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
+    dist.all_gather(gather_list, tensor)
+    return gather_list  # List of tensors from all ranks
+
+
+class DummyDataset(Dataset):
+    def __init__(self):
+        self.length = 64
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, i) -> dict:
+        x = random.random()
+        y = np.random.random()
+        z = torch.rand([]).item()
+        return {"x": torch.tensor([x, y, z])}
+
+
+class DummyModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc = nn.Linear(3, 1)
+
+    def forward(self, x):
+        local_tensor = torch.tensor(x, device="cuda")
+        gathered = gather_from_all_gpus(local_tensor, dist.get_world_size())
+        assert not all(torch.allclose(t, gathered[0]) for t in gathered[1:])
+        y = self.fc(x)
+        return (y.mean(), y)
+
+
+class TestTrainerDistributedWorkerSeed(TestCasePlus):
+    @require_torch_multi_gpu
+    def test_trainer(self):
+        device_count = torch.cuda.device_count()
+        output_dir = self.get_auto_remove_tmp_dir()
+        distributed_args = f"""--nproc_per_node={device_count}
+            --master_port={get_torch_dist_unique_port()}
+            {self.test_file_dir}/test_trainer_distributed_worker_seed.py
+        """.split()
+        args = f"--output_dir {output_dir}".split()
+        cmd = ["torchrun"] + distributed_args + args
+        execute_subprocess_async(cmd, env=self.get_env())
+
+
+def run_distributed_training(training_args):
+    set_seed(42)
+    model = DummyModel()
+    dataset = DummyDataset()
+    training_args.max_steps = 10
+    # dataloader_num_workers must be > 0 to enable worker_init_fn
+    training_args.dataloader_num_workers = 2
+    trainer = Trainer(
+        model,
+        training_args,
+        train_dataset=dataset,
+    )
+    trainer.train()
+
+
+if __name__ == "__main__":
+    parser = HfArgumentParser((TrainingArguments,))
+    training_args = parser.parse_args_into_dataclasses()[0]
+    run_distributed_training(training_args)