[trainer] add Trainer methods for metrics logging and saving (#10266)
* make logging and saving trainer built-in * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
|
||||
import collections
|
||||
import gc
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
@@ -1370,6 +1371,38 @@ class Trainer:
|
||||
|
||||
return metrics_copy
|
||||
|
||||
def log_metrics(self, split, metrics):
|
||||
"""
|
||||
Log metrics in a specially formatted way
|
||||
|
||||
Args:
|
||||
split (:obj:`str`):
|
||||
Mode/split name: one of ``train``, ``eval``, ``test``
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics returned from train/evaluate/predictmetrics: metrics dict
|
||||
"""
|
||||
|
||||
logger.info(f"***** {split} metrics *****")
|
||||
metrics_formatted = self.metrics_format(metrics)
|
||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||
for key in sorted(metrics_formatted.keys()):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
|
||||
def save_metrics(self, split, metrics):
|
||||
"""
|
||||
Save metrics into a json file for that split, e.g. ``train_results.json``.
|
||||
|
||||
Args:
|
||||
split (:obj:`str`):
|
||||
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics returned from train/evaluate/predict
|
||||
"""
|
||||
path = os.path.join(self.args.output_dir, f"{split}_results.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(metrics, f, indent=4, sort_keys=True)
|
||||
|
||||
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
||||
"""
|
||||
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
||||
|
||||
Reference in New Issue
Block a user