[trainer] add Trainer methods for metrics logging and saving (#10266)
* make logging and saving trainer built-in * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -18,7 +18,6 @@ Fine-tuning the library models for sequence to sequence.
|
|||||||
"""
|
"""
|
||||||
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -55,11 +54,6 @@ with FileLock(".lock") as lock:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def save_json(content, path, indent=4, **json_dump_kwargs):
|
|
||||||
with open(path, "w") as f:
|
|
||||||
json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelArguments:
|
class ModelArguments:
|
||||||
"""
|
"""
|
||||||
@@ -596,13 +590,8 @@ def main():
|
|||||||
)
|
)
|
||||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
metrics_formatted = trainer.metrics_format(metrics)
|
trainer.log_metrics("train", metrics)
|
||||||
logger.info("***** train metrics *****")
|
trainer.save_metrics("train", metrics)
|
||||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
|
||||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
|
||||||
for key in sorted(metrics_formatted.keys()):
|
|
||||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
|
||||||
save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
|
|
||||||
all_metrics.update(metrics)
|
all_metrics.update(metrics)
|
||||||
|
|
||||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||||
@@ -620,13 +609,8 @@ def main():
|
|||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||||
|
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
metrics_formatted = trainer.metrics_format(metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
logger.info("***** val metrics *****")
|
trainer.save_metrics("eval", metrics)
|
||||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
|
||||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
|
||||||
for key in sorted(metrics_formatted.keys()):
|
|
||||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
|
||||||
save_json(metrics, os.path.join(training_args.output_dir, "eval_results.json"))
|
|
||||||
all_metrics.update(metrics)
|
all_metrics.update(metrics)
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
@@ -643,13 +627,8 @@ def main():
|
|||||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||||
|
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
metrics_formatted = trainer.metrics_format(metrics)
|
trainer.log_metrics("test", metrics)
|
||||||
logger.info("***** test metrics *****")
|
trainer.save_metrics("test", metrics)
|
||||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
|
||||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
|
||||||
for key in sorted(metrics_formatted.keys()):
|
|
||||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
|
||||||
save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
|
|
||||||
all_metrics.update(metrics)
|
all_metrics.update(metrics)
|
||||||
|
|
||||||
if training_args.predict_with_generate:
|
if training_args.predict_with_generate:
|
||||||
@@ -662,7 +641,7 @@ def main():
|
|||||||
writer.write("\n".join(test_preds))
|
writer.write("\n".join(test_preds))
|
||||||
|
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
|
trainer.save_metrics("all", metrics)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
|
|||||||
import collections
|
import collections
|
||||||
import gc
|
import gc
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -1370,6 +1371,38 @@ class Trainer:
|
|||||||
|
|
||||||
return metrics_copy
|
return metrics_copy
|
||||||
|
|
||||||
|
def log_metrics(self, split, metrics):
|
||||||
|
"""
|
||||||
|
Log metrics in a specially formatted way
|
||||||
|
|
||||||
|
Args:
|
||||||
|
split (:obj:`str`):
|
||||||
|
Mode/split name: one of ``train``, ``eval``, ``test``
|
||||||
|
metrics (:obj:`Dict[str, float]`):
|
||||||
|
The metrics returned from train/evaluate/predictmetrics: metrics dict
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger.info(f"***** {split} metrics *****")
|
||||||
|
metrics_formatted = self.metrics_format(metrics)
|
||||||
|
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||||
|
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||||
|
for key in sorted(metrics_formatted.keys()):
|
||||||
|
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||||
|
|
||||||
|
def save_metrics(self, split, metrics):
|
||||||
|
"""
|
||||||
|
Save metrics into a json file for that split, e.g. ``train_results.json``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
split (:obj:`str`):
|
||||||
|
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
|
||||||
|
metrics (:obj:`Dict[str, float]`):
|
||||||
|
The metrics returned from train/evaluate/predict
|
||||||
|
"""
|
||||||
|
path = os.path.join(self.args.output_dir, f"{split}_results.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
json.dump(metrics, f, indent=4, sort_keys=True)
|
||||||
|
|
||||||
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
||||||
"""
|
"""
|
||||||
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
||||||
|
|||||||
Reference in New Issue
Block a user