add test related code: test the best dev acc model when model is training

This commit is contained in:
erenup
2019-08-28 15:57:29 +08:00
parent fc74132598
commit 3c7e676f8b

View File

@@ -169,7 +169,7 @@ def train(args, train_dataset, model, tokenizer):
results = evaluate(args, model, tokenizer) results = evaluate(args, model, tokenizer)
for key, value in results.items(): for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step) tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
if results["eval_loss"] < best_dev_loss: if results["eval_acc"] < best_dev_acc:
best_dev_acc = results["eval_acc"] best_dev_acc = results["eval_acc"]
best_dev_loss = results["eval_loss"] best_dev_loss = results["eval_loss"]
best_steps = global_step best_steps = global_step
@@ -469,12 +469,12 @@ def main():
model.to(args.device) model.to(args.device)
logger.info("Training/evaluation parameters %s", args) logger.info("Training/evaluation parameters %s", args)
best_steps = 0
# Training # Training
if args.do_train: if args.do_train:
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
global_step, tr_loss, _ = train(args, train_dataset, model, tokenizer) global_step, tr_loss, best_steps = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
@@ -522,7 +522,7 @@ def main():
if not args.do_train: if not args.do_train:
args.output_dir = args.model_name_or_path args.output_dir = args.model_name_or_path
checkpoints = [args.output_dir] checkpoints = [args.output_dir]
if args.eval_all_checkpoints: if args.eval_all_checkpoints: #can not use this to do test!! just for different paras
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
@@ -533,7 +533,8 @@ def main():
result = evaluate(args, model, tokenizer, prefix=global_step, test=True) result = evaluate(args, model, tokenizer, prefix=global_step, test=True)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result) results.update(result)
if best_steps:
logger.info("best steps of eval acc is the following checkpoints: %s", best_steps)
return results return results