[s2s] only save metrics.json from rank zero (#7331)
This commit is contained in:
@@ -8,6 +8,8 @@ import torch
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from utils import save_json
|
||||
|
||||
|
||||
def count_trainable_parameters(model):
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
@@ -72,8 +74,15 @@ class Seq2SeqLoggingCallback(pl.Callback):
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
return self._write_logs(trainer, pl_module, "test")
|
||||
|
||||
@rank_zero_only
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
# Uncommenting this will save val generations
|
||||
# return self._write_logs(trainer, pl_module, "valid")
|
||||
|
||||
|
||||
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
|
||||
"""Saves the best model by validation ROUGE2 score."""
|
||||
|
||||
@@ -30,7 +30,6 @@ from utils import (
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
@@ -189,7 +188,7 @@ class SummarizationModule(BaseTransformer):
|
||||
losses.update(generative_metrics)
|
||||
all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
all_metrics["step_count"] = self.step_count
|
||||
self.save_metrics(all_metrics, prefix) # writes to self.metrics_save_path
|
||||
self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {
|
||||
"log": all_metrics,
|
||||
@@ -198,10 +197,6 @@ class SummarizationModule(BaseTransformer):
|
||||
f"{prefix}_{self.val_metric}": metric_tensor,
|
||||
}
|
||||
|
||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||
self.metrics[type_path].append(latest_metrics)
|
||||
save_json(self.metrics, self.metrics_save_path)
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_rouge(preds, target)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user