FIX errors in loading eval Dataset in run_squad_pytorch
This commit is contained in:
@@ -865,10 +865,11 @@ def main():
|
|||||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
#all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||||
|
|
||||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
|
#eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
|
||||||
|
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
eval_sampler = SequentialSampler(eval_data)
|
eval_sampler = SequentialSampler(eval_data)
|
||||||
else:
|
else:
|
||||||
@@ -877,7 +878,8 @@ def main():
|
|||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
all_results = []
|
all_results = []
|
||||||
for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
|
#for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
|
||||||
|
for input_ids, input_mask, segment_ids, example_index in eval_dataloader:
|
||||||
if len(all_results) % 1000 == 0:
|
if len(all_results) % 1000 == 0:
|
||||||
logger.info("Processing example: %d" % (len(all_results)))
|
logger.info("Processing example: %d" % (len(all_results)))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user