[examples] better model example (#10427)

* refactors

* typo
This commit is contained in:
Stas Bekman
2021-02-26 17:01:01 -08:00
committed by GitHub
parent a85eb616f7
commit ee04b69822
3 changed files with 46 additions and 20 deletions

View File

@@ -231,7 +231,7 @@ class Trainer:
"""
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
def __init__(
self,

View File

@@ -599,12 +599,16 @@ def log_metrics(self, split, metrics):
"""
Log metrics in a specially formatted way
Under distributed environment this is done only for a process with rank 0.
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predictmetrics: metrics dict
"""
if not self.is_world_process_zero():
return
logger.info(f"***** {split} metrics *****")
metrics_formatted = self.metrics_format(metrics)
@@ -614,16 +618,48 @@ def log_metrics(self, split, metrics):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
def save_metrics(self, split, metrics):
def save_metrics(self, split, metrics, combined=True):
"""
Save metrics into a json file for that split, e.g. ``train_results.json``.
Under distributed environment this is done only for a process with rank 0.
Args:
split (:obj:`str`):
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
metrics (:obj:`Dict[str, float]`):
The metrics returned from train/evaluate/predict
combined (:obj:`bool`, `optional`, defaults to :obj:`True`):
Creates combined metrics by updating ``all_results.json`` with metrics of this call
"""
if not self.is_world_process_zero():
return
path = os.path.join(self.args.output_dir, f"{split}_results.json")
with open(path, "w") as f:
json.dump(metrics, f, indent=4, sort_keys=True)
if combined:
path = os.path.join(self.args.output_dir, "all_results.json")
if os.path.exists(path):
with open(path, "r") as f:
all_metrics = json.load(f)
else:
all_metrics = {}
all_metrics.update(metrics)
with open(path, "w") as f:
json.dump(all_metrics, f, indent=4, sort_keys=True)
def save_state(self):
"""
Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model
Under distributed environment this is done only for a process with rank 0.
"""
if not self.is_world_process_zero():
return
path = os.path.join(self.args.output_dir, "trainer_state.json")
self.state.save_to_json(path)