Add test-related code: test the model with the best dev accuracy while the model is training
This commit is contained in:
@@ -169,7 +169,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
results = evaluate(args, model, tokenizer)
|
results = evaluate(args, model, tokenizer)
|
||||||
for key, value in results.items():
|
for key, value in results.items():
|
||||||
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
|
||||||
if results["eval_loss"] < best_dev_loss:
|
if results["eval_acc"] < best_dev_acc:
|
||||||
best_dev_acc = results["eval_acc"]
|
best_dev_acc = results["eval_acc"]
|
||||||
best_dev_loss = results["eval_loss"]
|
best_dev_loss = results["eval_loss"]
|
||||||
best_steps = global_step
|
best_steps = global_step
|
||||||
@@ -469,12 +469,12 @@ def main():
|
|||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
best_steps = 0
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
||||||
global_step, tr_loss, _ = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss, best_steps = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
@@ -522,7 +522,7 @@ def main():
|
|||||||
if not args.do_train:
|
if not args.do_train:
|
||||||
args.output_dir = args.model_name_or_path
|
args.output_dir = args.model_name_or_path
|
||||||
checkpoints = [args.output_dir]
|
checkpoints = [args.output_dir]
|
||||||
if args.eval_all_checkpoints:
|
if args.eval_all_checkpoints: # Cannot be used for testing; only for comparing different parameters
|
||||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||||
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||||
@@ -533,7 +533,8 @@ def main():
|
|||||||
result = evaluate(args, model, tokenizer, prefix=global_step, test=True)
|
result = evaluate(args, model, tokenizer, prefix=global_step, test=True)
|
||||||
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
|
||||||
results.update(result)
|
results.update(result)
|
||||||
|
if best_steps:
|
||||||
|
logger.info("best steps of eval acc is the following checkpoints: %s", best_steps)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user