From a432b3d466132e446e2c452a9012bb576cf9f361 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 14:39:09 +0200 Subject: [PATCH] distributed training t_total --- examples/run_squad.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index e904187500..fb3b4b7d34 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -234,10 +234,11 @@ def main(): train_sampler = RandomSampler(train_data) else: train_sampler = DistributedSampler(train_data) + train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs - if args.local_rank != -1: - num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() + # if args.local_rank != -1: + # num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() # Prepare optimizer param_optimizer = list(model.named_parameters())