Add timing inside Trainer (#9196)

* Add timing inside Trainer

* Fix tests

* Add n_objs for train

* Sort logs
This commit is contained in:
Sylvain Gugger
2020-12-18 15:10:39 -05:00
committed by GitHub
parent 9a25c5bd3a
commit 1198ba8fba
6 changed files with 76 additions and 49 deletions

View File

@@ -16,7 +16,6 @@
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from typing import Optional
@@ -120,30 +119,6 @@ class DataTrainingArguments:
)
def speed_metrics(split, start_time, num_samples):
"""
Measure and return speed performance metrics.
This function requires a time snapshot `start_time` before the operation to be measured starts and this
function should be run immediately after the operation to be measured has completed.
Args:
- split: one of train, val, test
- start_time: operation start time
- num_samples: number of samples processed
"""
runtime = time.time() - start_time
result = {}
samples_per_second = 1 / (runtime / num_samples)
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
result[f"{split}_runtime"] = round(runtime, 4)
result[f"{split}_n_ojbs"] = num_samples
return result
def handle_metrics(split, metrics, output_dir):
"""
Log and save metrics
@@ -155,8 +130,8 @@ def handle_metrics(split, metrics, output_dir):
"""
logger.info(f"***** {split} metrics *****")
for key, value in metrics.items():
logger.info(f" {key} = {value}")
for key in sorted(metrics.keys()):
logger.info(f" {key} = {metrics[key]}")
save_json(metrics, os.path.join(output_dir, f"{split}_results.json"))
@@ -311,11 +286,11 @@ def main():
if training_args.do_train:
logger.info("*** Train ***")
start_time = time.time()
trainer.train(
train_result = trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
)
metrics = speed_metrics("train", start_time, data_args.n_train)
metrics = train_result.metrics
metrics["train_n_objs"] = data_args.n_train
trainer.save_model() # this also saves the tokenizer
@@ -334,9 +309,8 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")
start_time = time.time()
metrics = trainer.evaluate(metric_key_prefix="val")
metrics.update(speed_metrics("val", start_time, data_args.n_val))
metrics["val_n_objs"] = data_args.n_val
metrics["val_loss"] = round(metrics["val_loss"], 4)
if trainer.is_world_process_zero():
@@ -347,10 +321,9 @@ def main():
if training_args.do_predict:
logger.info("*** Predict ***")
start_time = time.time()
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
metrics = test_output.metrics
metrics.update(speed_metrics("test", start_time, data_args.n_test))
metrics["test_n_objs"] = data_args.n_test
if trainer.is_world_process_zero():
metrics["test_loss"] = round(metrics["test_loss"], 4)

View File

@@ -97,9 +97,7 @@ class ExamplesTests(TestCasePlus):
with patch.object(sys, "argv", testargs):
result = run_glue.main()
del result["eval_loss"]
for value in result.values():
self.assertGreaterEqual(value, 0.75)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
@require_torch_non_multi_gpu_but_fix_me
def test_run_clm(self):