From 6b850b671def8bb5a1e98f519e4bf3e01b86cf93 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Fri, 18 Dec 2020 17:09:30 -0800 Subject: [PATCH] [run_glue] add speed metrics (#9198) * add speed metrics * suggestions --- examples/text-classification/run_glue.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 941b3c84d0..afe9cd9761 100644 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -350,11 +350,24 @@ def main(): # Training if training_args.do_train: - trainer.train( + train_result = trainer.train( model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None ) + metrics = train_result.metrics + trainer.save_model() # Saves the tokenizer too for easy upload + output_train_file = os.path.join(training_args.output_dir, "train_results.txt") + if trainer.is_world_process_zero(): + with open(output_train_file, "w") as writer: + logger.info("***** Train results *****") + for key, value in sorted(metrics.items()): + logger.info(f" {key} = {value}") + writer.write(f"{key} = {value}\n") + + # Need to save the state, since Trainer.save_model saves only the tokenizer with the model + trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json")) + # Evaluation eval_results = {} if training_args.do_eval: @@ -374,7 +387,7 @@ def main(): if trainer.is_world_process_zero(): with open(output_eval_file, "w") as writer: logger.info(f"***** Eval results {task} *****") - for key, value in eval_result.items(): + for key, value in sorted(eval_result.items()): logger.info(f" {key} = {value}") writer.write(f"{key} = {value}\n")