From 833c3a7a25c78a1ba4b55b9356f956c7b95d7f37 Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 04:00:00 -0400 Subject: [PATCH] FIX errors in loading Dataset in `run_squad_pytorch` --- run_squad_pytorch.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py index 64803bacc3..0d9e6f8699 100644 --- a/run_squad_pytorch.py +++ b/run_squad_pytorch.py @@ -818,9 +818,12 @@ def main(): all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) - all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) + #all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) + all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long) + all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long) - train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) + #train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) + train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions) if args.local_rank == -1: train_sampler = RandomSampler(train_data) else: @@ -829,13 +832,16 @@ def main(): model.train() for epoch in range(int(args.num_train_epochs)): - for input_ids, input_mask, segment_ids, label_ids in train_dataloader: + #for input_ids, input_mask, segment_ids, label_ids in train_dataloader: + for input_ids, input_mask, segment_ids, start_positions, end_positions in train_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) segment_ids = segment_ids.to(device) - label_ids = label_ids.to(device) + #label_ids = label_ids.to(device) + start_positions = start_positions.to(device) + end_positions = start_positions.to(device) - loss, _ = model(input_ids, segment_ids, input_mask, label_ids) + loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) loss.backward() optimizer.step() global_step += 1