update seed_worker to set seed based on worker_id and rank (#37980)

* update seed_worker to set seed based on worker_id and rank

* test case

* set output_dir to an auto-removed tmp dir
This commit is contained in:
Shiyu
2025-05-12 23:59:16 +08:00
committed by GitHub
parent e387821a96
commit a63cb7578e
3 changed files with 96 additions and 3 deletions

View File

@@ -32,6 +32,7 @@ import tempfile
import time import time
import warnings import warnings
from collections.abc import Mapping from collections.abc import Mapping
from functools import partial
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -1028,7 +1029,9 @@ class Trainer:
if not isinstance(train_dataset, torch.utils.data.IterableDataset): if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler() dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker dataloader_params["worker_init_fn"] = partial(
seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

View File

@@ -49,11 +49,12 @@ if is_torch_available():
import torch import torch
def seed_worker(worker_id: int, num_workers: int, rank: int):
    """
    Helper function to set worker seed during Dataloader initialization.

    Intended to be used as a DataLoader `worker_init_fn` with `num_workers` and `rank` bound in advance (e.g. via
    `functools.partial`), so that DataLoader only has to supply `worker_id`.

    Args:
        worker_id (`int`):
            Index of the DataLoader worker, passed positionally by DataLoader. It is not read here: per the PyTorch
            data-loading notes each worker is already seeded with `base_seed + worker_id`, so `torch.initial_seed()`
            differs between the workers of one process.
        num_workers (`int`):
            Number of DataLoader workers per process; used as the stride between the seed ranges of two ranks.
        rank (`int`):
            Index of the current process. Offsets the seed so that workers with the same `worker_id` on different
            ranks do not end up with the same seed.
    """
    # Reduce modulo 2**32 *after* adding the rank offset: the sum could otherwise reach 2**32 or more, which is
    # outside the range NumPy accepts as a seed. The result is unchanged whenever no overflow occurs.
    worker_seed = (num_workers * rank + torch.initial_seed()) % 2**32
    set_seed(worker_seed)

View File

@@ -0,0 +1,89 @@
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_gpu,
)
def gather_from_all_gpus(tensor, world_size):
    """
    Collect `tensor` from every process of the default process group.

    Args:
        tensor: The local tensor contributed by the calling process.
        world_size: Number of receive buffers to allocate (one per process).

    Returns:
        A list of `world_size` tensors filled in by `dist.all_gather`.
    """
    # all_gather writes into pre-allocated buffers shaped like the local tensor.
    buffers = []
    for _ in range(world_size):
        buffers.append(torch.zeros_like(tensor))
    dist.all_gather(buffers, tensor)
    return buffers
class DummyDataset(Dataset):
    """
    Map-style dataset of 64 samples whose content is drawn from random number generators.

    Every sample combines one value from Python's `random`, one from NumPy and one from PyTorch, so the sample
    depends on the state of all three generators in the process that produces it.
    """

    def __init__(self):
        self.length = 64

    def __len__(self):
        return self.length

    def __getitem__(self, i) -> dict[str, torch.Tensor]:
        # The index is ignored on purpose: the sample is determined only by the current RNG states.
        x = random.random()
        y = np.random.random()
        z = torch.rand([]).item()
        return {"x": torch.tensor([x, y, z])}
class DummyModel(nn.Module):
    """
    Minimal model whose forward pass checks that the processes did not all receive the same batch.

    `forward` gathers the input batch from every process and fails if every other process holds the same values as
    the first gathered tensor. It then returns `(y.mean(), y)` for the linear projection `y` of the input.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        # `x` is already a tensor: copy it with detach().clone() rather than torch.tensor(x), which warns when
        # copy-constructing from a tensor. The copy stays on the device `x` lives on.
        # NOTE(review): assumes the caller has already placed `x` on the device used by the process group.
        local_tensor = x.detach().clone()
        gathered = gather_from_all_gpus(local_tensor, dist.get_world_size())
        # With a single process `gathered[1:]` is empty and this check fails; it needs more than one process.
        assert not all(torch.allclose(t, gathered[0]) for t in gathered[1:]), (
            "every process received the same batch as the first one"
        )
        y = self.fc(x)
        return (y.mean(), y)
class TestTrainerDistributedWorkerSeed(TestCasePlus):
    """
    Runs `test_trainer_distributed_worker_seed.py` from `self.test_file_dir` under `torchrun`.
    """

    @require_torch_multi_gpu
    def test_trainer(self):
        """Launch the script with one process per visible CUDA device and a temporary output directory."""
        output_dir = self.get_auto_remove_tmp_dir()
        # Build argv as a list instead of splitting formatted strings, so that a path containing whitespace stays
        # a single argument.
        cmd = [
            "torchrun",
            f"--nproc_per_node={torch.cuda.device_count()}",
            f"--master_port={get_torch_dist_unique_port()}",
            f"{self.test_file_dir}/test_trainer_distributed_worker_seed.py",
            "--output_dir",
            str(output_dir),
        ]
        execute_subprocess_async(cmd, env=self.get_env())
def run_distributed_training(training_args):
    """
    Train `DummyModel` on `DummyDataset` with `Trainer`, using a fixed seed of 42.

    `training_args` is modified in place: `max_steps` is set to 10 and `dataloader_num_workers` to 2.
    """
    set_seed(42)

    # dataloader_num_workers must be > 0 to enable worker_init_fn
    training_args.dataloader_num_workers = 2
    training_args.max_steps = 10

    Trainer(
        model=DummyModel(),
        args=training_args,
        train_dataset=DummyDataset(),
    ).train()
if __name__ == "__main__":
    # Script entry point: turn the command-line flags into a TrainingArguments instance and start training.
    parser = HfArgumentParser((TrainingArguments,))
    parsed = parser.parse_args_into_dataclasses()
    training_args = parsed[0]
    run_distributed_training(training_args)