@@ -231,7 +231,7 @@ class Trainer:
|
||||
|
||||
"""
|
||||
|
||||
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
|
||||
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -599,12 +599,16 @@ def log_metrics(self, split, metrics):
|
||||
"""
|
||||
Log metrics in a specially formatted way
|
||||
|
||||
Under distributed environment this is done only for a process with rank 0.
|
||||
|
||||
Args:
|
||||
split (:obj:`str`):
|
||||
Mode/split name: one of ``train``, ``eval``, ``test``
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics returned from train/evaluate/predict
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
logger.info(f"***** {split} metrics *****")
|
||||
metrics_formatted = self.metrics_format(metrics)
|
||||
@@ -614,16 +618,48 @@ def log_metrics(self, split, metrics):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
|
||||
|
||||
def save_metrics(self, split, metrics, combined=True):
    """
    Save metrics into a json file for that split, e.g. ``train_results.json``.

    Under distributed environment this is done only for a process with rank 0.

    Args:
        split (:obj:`str`):
            Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
        metrics (:obj:`Dict[str, float]`):
            The metrics returned from train/evaluate/predict
        combined (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Creates combined metrics by updating ``all_results.json`` with metrics of this call
    """
    if not self.is_world_process_zero():
        return

    # Per-split file: always overwritten with the metrics of this call.
    path = os.path.join(self.args.output_dir, f"{split}_results.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)

    if not combined:
        return

    # Combined file: merge this call's metrics into whatever is already stored.
    path = os.path.join(self.args.output_dir, "all_results.json")
    try:
        # Open directly instead of checking os.path.exists() first, so there is
        # no window between the check and the read in which the file can vanish.
        with open(path, "r", encoding="utf-8") as f:
            all_metrics = json.load(f)
    except FileNotFoundError:
        all_metrics = {}

    # Keys from this call win over previously stored keys of the same name.
    all_metrics.update(metrics)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(all_metrics, f, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
def save_state(self):
    """
    Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model

    The state is written to ``trainer_state.json`` inside ``self.args.output_dir``.

    Under distributed environment this is done only for a process with rank 0.
    """
    if self.is_world_process_zero():
        state_file = os.path.join(self.args.output_dir, "trainer_state.json")
        self.state.save_to_json(state_file)
|
||||
|
||||
Reference in New Issue
Block a user