From 47e9aea0fe9b8b83459f0744fa47bda6c6d4a699 Mon Sep 17 00:00:00 2001 From: erenup Date: Sun, 18 Aug 2019 17:00:53 +0800 Subject: [PATCH] add args info to evaluate_result.txt --- examples/single_model_scripts/run_multiple_choice.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/single_model_scripts/run_multiple_choice.py b/examples/single_model_scripts/run_multiple_choice.py index 4d42a73f99..b0879b48b5 100644 --- a/examples/single_model_scripts/run_multiple_choice.py +++ b/examples/single_model_scripts/run_multiple_choice.py @@ -249,12 +249,18 @@ def evaluate(args, model, tokenizer, prefix=""): results.update(result) output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") + with open(output_eval_file, "w") as writer: logger.info("***** Eval results {} *****".format(prefix)) + writer.write("model =%s\n" % str(args.model_name_or_path)) + writer.write("total batch size=%d\n" % (args.train_batch_size * args.gradient_accumulation_steps * + (torch.distributed.get_world_size() if args.local_rank != -1 else 1))) + writer.write("train num epochs=%d\n" % args.num_train_epochs) + writer.write("fp16 =%s\n" % args.fp16) + writer.write("max seq length =%d\n" % args.max_seq_length) for key in sorted(result.keys()): logger.info(" %s = %s", key, str(result[key])) writer.write("%s = %s\n" % (key, str(result[key]))) - return results