From 5ebd8989530e874a9666852ef1bc46e0781dc8e8 Mon Sep 17 00:00:00 2001 From: elk-cloner Date: Mon, 13 Apr 2020 18:41:18 +0430 Subject: [PATCH] fix dataset shuffling for Distributed training (#huggingface#3721) (#3766) --- examples/run_language_modeling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/run_language_modeling.py b/examples/run_language_modeling.py index 2b0163d96a..5d451e7612 100644 --- a/examples/run_language_modeling.py +++ b/examples/run_language_modeling.py @@ -317,8 +317,12 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0] ) set_seed(args) # Added here for reproducibility - for _ in train_iterator: + for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + + if args.local_rank != -1: + train_sampler.set_epoch(epoch) + for step, batch in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training