diff --git a/run_squad.py b/run_squad.py index 8a69e057e5..50d450d85a 100644 --- a/run_squad.py +++ b/run_squad.py @@ -908,7 +908,7 @@ def main(): model.eval() all_results = [] logger.info("Start evaluating") - for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, desc="Evaluating"): + for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"): if len(all_results) % 1000 == 0: logger.info("Processing example: %d" % (len(all_results))) @@ -916,21 +916,18 @@ def main(): input_mask = input_mask.to(device) segment_ids = segment_ids.to(device) - start_logits, end_logits = model(input_ids, segment_ids, input_mask) + with torch.no_grad(): + batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask) - unique_id = [int(eval_features[e.item()].unique_id) for e in example_index] - start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits] - end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits] - for idx, i in enumerate(unique_id): - s = [float(x) for x in start_logits[idx]] - e = [float(x) for x in end_logits[idx]] - all_results.append( - RawResult( - unique_id=i, - start_logits=s, - end_logits=e - ) - ) + for i, example_index in enumerate(example_indices): + start_logits = batch_start_logits[i].detach().cpu().tolist() + end_logits = batch_end_logits[i].detach().cpu().tolist() + + eval_feature = eval_features[example_index.item()] + unique_id = int(eval_feature.unique_id) + all_results.append(RawResult(unique_id=unique_id, + start_logits=start_logits, + end_logits=end_logits)) output_prediction_file = os.path.join(args.output_dir, "predictions.json") output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")