[Trainer] memory tracker metrics (#10225)
* memory tracker metrics
* go back to eval for somewhat consistency
* handle no-gpu case
* deal with stackable eval calls
* restore callback order
* style
* simplify the API
* add test
* docs
* consistently use eval_ prefix
* improve docs
* Update src/transformers/trainer_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* rename method
* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -588,9 +588,12 @@ def main():
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
if trainer.is_world_process_zero():
|
||||
metrics_formatted = trainer.metrics_format(metrics)
|
||||
logger.info("***** train metrics *****")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||
for key in sorted(metrics_formatted.keys()):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
|
||||
all_metrics.update(metrics)
|
||||
|
||||
@@ -603,17 +606,19 @@ def main():
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
metrics = trainer.evaluate(
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="val"
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
||||
)
|
||||
metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["val_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
metrics_formatted = trainer.metrics_format(metrics)
|
||||
logger.info("***** val metrics *****")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "val_results.json"))
|
||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||
for key in sorted(metrics_formatted.keys()):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "eval_results.json"))
|
||||
all_metrics.update(metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
@@ -628,12 +633,14 @@ def main():
|
||||
metrics = test_results.metrics
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
metrics = {k: round(v, 4) for k, v in metrics.items()}
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
metrics_formatted = trainer.metrics_format(metrics)
|
||||
logger.info("***** test metrics *****")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||
for key in sorted(metrics_formatted.keys()):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
|
||||
all_metrics.update(metrics)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user