@@ -572,7 +572,6 @@ def main():
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
all_metrics = {}
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
if last_checkpoint is not None:
|
||||
@@ -589,13 +588,10 @@ def main():
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
if trainer.is_world_process_zero():
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
all_metrics.update(metrics)
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
@@ -608,10 +604,8 @@ def main():
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
all_metrics.update(metrics)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Test ***")
|
||||
@@ -626,11 +620,10 @@ def main():
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
all_metrics.update(metrics)
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(
|
||||
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
@@ -640,9 +633,6 @@ def main():
|
||||
with open(output_test_preds_file, "w") as writer:
|
||||
writer.write("\n".join(test_preds))
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
trainer.save_metrics("all", metrics)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ class Trainer:
|
||||
|
||||
"""
|
||||
|
||||
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
|
||||
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -599,12 +599,16 @@ def log_metrics(self, split, metrics):
|
||||
"""
|
||||
Log metrics in a specially formatted way
|
||||
|
||||
Under distributed environment this is done only for a process with rank 0.
|
||||
|
||||
Args:
|
||||
split (:obj:`str`):
|
||||
Mode/split name: one of ``train``, ``eval``, ``test``
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics returned from train/evaluate/predictmetrics: metrics dict
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
logger.info(f"***** {split} metrics *****")
|
||||
metrics_formatted = self.metrics_format(metrics)
|
||||
@@ -614,16 +618,48 @@ def log_metrics(self, split, metrics):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
|
||||
|
||||
def save_metrics(self, split, metrics):
|
||||
def save_metrics(self, split, metrics, combined=True):
|
||||
"""
|
||||
Save metrics into a json file for that split, e.g. ``train_results.json``.
|
||||
|
||||
Under distributed environment this is done only for a process with rank 0.
|
||||
|
||||
Args:
|
||||
split (:obj:`str`):
|
||||
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics returned from train/evaluate/predict
|
||||
combined (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Creates combined metrics by updating ``all_results.json`` with metrics of this call
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
path = os.path.join(self.args.output_dir, f"{split}_results.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(metrics, f, indent=4, sort_keys=True)
|
||||
|
||||
if combined:
|
||||
path = os.path.join(self.args.output_dir, "all_results.json")
|
||||
if os.path.exists(path):
|
||||
with open(path, "r") as f:
|
||||
all_metrics = json.load(f)
|
||||
else:
|
||||
all_metrics = {}
|
||||
|
||||
all_metrics.update(metrics)
|
||||
with open(path, "w") as f:
|
||||
json.dump(all_metrics, f, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
def save_state(self):
|
||||
"""
|
||||
Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model
|
||||
|
||||
Under distributed environment this is done only for a process with rank 0.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
path = os.path.join(self.args.output_dir, "trainer_state.json")
|
||||
self.state.save_to_json(path)
|
||||
|
||||
Reference in New Issue
Block a user