From bb0a510330a147554e44376b4d2cfada8e1362b1 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 10:16:07 -0400 Subject: [PATCH] Print for debug run_squad --- run_squad_pytorch.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py index 86293fd194..38d8447fa4 100644 --- a/run_squad_pytorch.py +++ b/run_squad_pytorch.py @@ -818,6 +818,7 @@ def main(): logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Num steps = %d", num_train_steps) + logger.info("HHHHH Loading data") all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) @@ -825,14 +826,17 @@ def main(): all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) + logger.info("HHHHH Creating dataset") #train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions) if args.local_rank == -1: train_sampler = RandomSampler(train_data) else: train_sampler = DistributedSampler(train_data) + logger.info("HHHHH Dataloader") train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) + logger.info("HHHHH Starting Traing") model.train() for epoch in range(int(args.num_train_epochs)): #for input_ids, input_mask, segment_ids, label_ids in train_dataloader: @@ -846,11 +850,15 @@ def main(): start_positions = start_positions.view(-1, 1) end_positions = end_positions.view(-1, 1) - + + logger.info("HHHHH Forward") loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) + logger.info("HHHHH Backward") loss.backward() + logger.info("HHHHH Loading data") optimizer.step() global_step += 1 + logger.info("Done %s steps", global_step) if args.do_predict: eval_examples = read_squad_examples( @@ -884,6 +892,7 @@ def main(): model.eval() all_results = [] + logger.info("Start evaulating") #for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader: for input_ids, input_mask, segment_ids, example_index in eval_dataloader: if len(all_results) % 1000 == 0: