From 0919389d9aa03c19bee2ae9bc9922ff15ec8381b Mon Sep 17 00:00:00 2001 From: William Tambellini Date: Thu, 17 Oct 2019 14:41:04 -0700 Subject: [PATCH] Add speed log to examples/run_squad.py Add a speed estimate log (time per example) for evaluation to examples/run_squad.py --- examples/run_squad.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/run_squad.py b/examples/run_squad.py index 71c656a13d..f64ed13ae8 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -22,6 +22,7 @@ import logging import os import random import glob +import timeit import numpy as np import torch @@ -218,6 +219,7 @@ def evaluate(args, model, tokenizer, prefix=""): logger.info(" Num examples = %d", len(dataset)) logger.info(" Batch size = %d", args.eval_batch_size) all_results = [] + start_time = timeit.default_timer() for batch in tqdm(eval_dataloader, desc="Evaluating"): model.eval() batch = tuple(t.to(args.device) for t in batch) @@ -250,6 +252,9 @@ def evaluate(args, model, tokenizer, prefix=""): end_logits = to_list(outputs[1][i])) all_results.append(result) + evalTime = timeit.default_timer() - start_time + logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset)) + # Compute predictions output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))