diff --git a/examples/run_language_modeling.py b/examples/run_language_modeling.py index 2b0163d96a..5d451e7612 100644 --- a/examples/run_language_modeling.py +++ b/examples/run_language_modeling.py @@ -317,8 +317,12 @@ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedToke epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0] ) set_seed(args) # Added here for reproducibility - for _ in train_iterator: + for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + + if args.local_rank != -1: + train_sampler.set_epoch(epoch) + for step, batch in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training