From 1701291ef9c927f3810d2932042e18850d115208 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Sun, 4 Nov 2018 11:54:57 +0100 Subject: [PATCH] multi-gpu cleanup --- run_classifier.py | 4 +++- run_squad.py | 13 +++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/run_classifier.py b/run_classifier.py index 683fe1bfa6..367b0c5221 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -539,9 +539,11 @@ def main(): label_ids = label_ids.to(device) loss, _ = model(input_ids, segment_ids, input_mask, label_ids) - loss = loss.mean() # sum() is to account for multi-gpu support. + if n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu. tr_loss += loss.item() nb_tr_examples += input_ids.size(0) + model.zero_grad() loss.backward() optimizer.step() diff --git a/run_squad.py b/run_squad.py index 40905b3feb..4529fc6ffc 100644 --- a/run_squad.py +++ b/run_squad.py @@ -837,25 +837,19 @@ def main(): logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", num_train_steps) - logger.info("HHHHH Loading data") all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) - #all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) - logger.info("HHHHH Creating dataset") - #train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions) if args.local_rank == -1: train_sampler = RandomSampler(train_data) else: train_sampler = DistributedSampler(train_data) - logger.info("HHHHH Dataloader") train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) - logger.info("HHHHH Starting Traing") model.train() for epoch in trange(int(args.num_train_epochs), desc="Epoch"): for input_ids, input_mask, segment_ids, start_positions, end_positions in tqdm(train_dataloader, @@ -863,19 +857,18 @@ def main(): input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) segment_ids = segment_ids.to(device) - #label_ids = label_ids.to(device) start_positions = start_positions.to(device) end_positions = start_positions.to(device) start_positions = start_positions.view(-1, 1) end_positions = end_positions.view(-1, 1) - logger.info("HHHHH Forward") loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) + if n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu. + model.zero_grad() - logger.info("HHHHH Backward, loss: {}".format(loss)) loss.backward() - logger.info("HHHHH Loading data") optimizer.step() global_step += 1 logger.info("Done %s steps", global_step)