From 3c7e676f8b3cafd501a2616030f4bd1e512212f9 Mon Sep 17 00:00:00 2001 From: erenup Date: Wed, 28 Aug 2019 15:57:29 +0800 Subject: [PATCH] add test related code: test the best dev acc model when model is training --- examples/single_model_scripts/run_multiple_choice.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/single_model_scripts/run_multiple_choice.py b/examples/single_model_scripts/run_multiple_choice.py index 7f32c6cf7d..9784dfe94d 100644 --- a/examples/single_model_scripts/run_multiple_choice.py +++ b/examples/single_model_scripts/run_multiple_choice.py @@ -169,7 +169,7 @@ def train(args, train_dataset, model, tokenizer): results = evaluate(args, model, tokenizer) for key, value in results.items(): tb_writer.add_scalar('eval_{}'.format(key), value, global_step) - if results["eval_loss"] < best_dev_loss: + if results["eval_acc"] < best_dev_acc: best_dev_acc = results["eval_acc"] best_dev_loss = results["eval_loss"] best_steps = global_step @@ -469,12 +469,12 @@ def main(): model.to(args.device) logger.info("Training/evaluation parameters %s", args) - + best_steps = 0 # Training if args.do_train: train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) - global_step, tr_loss, _ = train(args, train_dataset, model, tokenizer) + global_step, tr_loss, best_steps = train(args, train_dataset, model, tokenizer) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) @@ -522,7 +522,7 @@ def main(): if not args.do_train: args.output_dir = args.model_name_or_path checkpoints = [args.output_dir] - if args.eval_all_checkpoints: + if args.eval_all_checkpoints: #can not use this to do test!! just for different paras checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging logger.info("Evaluate the following checkpoints: %s", checkpoints) @@ -533,7 +533,8 @@ def main(): result = evaluate(args, model, tokenizer, prefix=global_step, test=True) result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) results.update(result) - + if best_steps: + logger.info("best steps of eval acc is the following checkpoints: %s", best_steps) return results