[Trainer] memory tracker metrics (#10225)

* memory tracker metrics

* go back to eval for somewhat consistency

* handle no-gpu case

* deal with stackable eval calls

* restore callback order

* style

* simplify the API

* add test

* docs

* consistently use eval_ prefix

* improve docs

* Update src/transformers/trainer_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* rename method

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-02-18 09:27:32 -08:00
committed by GitHub
parent d7f38c5d1d
commit 97e688bc22
7 changed files with 294 additions and 14 deletions

View File

@@ -588,9 +588,12 @@ def main():
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
if trainer.is_world_process_zero():
metrics_formatted = trainer.metrics_format(metrics)
logger.info("***** train metrics *****")
for key in sorted(metrics.keys()):
logger.info(f" {key} = {metrics[key]}")
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
all_metrics.update(metrics)
@@ -603,17 +606,19 @@ def main():
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="val"
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
)
metrics = {k: round(v, 4) for k, v in metrics.items()}
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["val_samples"] = min(max_val_samples, len(eval_dataset))
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
if trainer.is_world_process_zero():
metrics_formatted = trainer.metrics_format(metrics)
logger.info("***** val metrics *****")
for key in sorted(metrics.keys()):
logger.info(f" {key} = {metrics[key]}")
save_json(metrics, os.path.join(training_args.output_dir, "val_results.json"))
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
save_json(metrics, os.path.join(training_args.output_dir, "eval_results.json"))
all_metrics.update(metrics)
if training_args.do_predict:
@@ -628,12 +633,14 @@ def main():
metrics = test_results.metrics
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
metrics = {k: round(v, 4) for k, v in metrics.items()}
if trainer.is_world_process_zero():
metrics_formatted = trainer.metrics_format(metrics)
logger.info("***** test metrics *****")
for key in sorted(metrics.keys()):
logger.info(f" {key} = {metrics[key]}")
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
all_metrics.update(metrics)