Add speed metrics to all example scripts + template (#9260)

This commit is contained in:
Sylvain Gugger
2020-12-22 14:02:26 -05:00
committed by GitHub
parent 5b5f7dd09c
commit ab17758874
9 changed files with 118 additions and 19 deletions

View File

@@ -376,9 +376,20 @@ def main():
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
else None
)
trainer.train(model_path=model_path)
train_result = trainer.train(model_path=model_path)
trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
if trainer.is_world_process_zero():
with open(output_train_file, "w") as writer:
logger.info("***** Train results *****")
for key, value in sorted(train_result.metrics.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
# Evaluation
results = {}
if training_args.do_eval:
@@ -393,7 +404,7 @@ def main():
if trainer.is_world_process_zero():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in results.items():
for key, value in sorted(results.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")