[s2s] only save metrics.json from rank zero (#7331)
This commit is contained in:
@@ -8,6 +8,8 @@ import torch
|
|||||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
from pytorch_lightning.utilities import rank_zero_only
|
from pytorch_lightning.utilities import rank_zero_only
|
||||||
|
|
||||||
|
from utils import save_json
|
||||||
|
|
||||||
|
|
||||||
def count_trainable_parameters(model):
|
def count_trainable_parameters(model):
|
||||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||||
@@ -72,8 +74,15 @@ class Seq2SeqLoggingCallback(pl.Callback):
|
|||||||
|
|
||||||
@rank_zero_only
|
@rank_zero_only
|
||||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||||
|
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||||
return self._write_logs(trainer, pl_module, "test")
|
return self._write_logs(trainer, pl_module, "test")
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def on_validation_end(self, trainer: pl.Trainer, pl_module):
|
||||||
|
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||||
|
# Uncommenting this will save val generations
|
||||||
|
# return self._write_logs(trainer, pl_module, "valid")
|
||||||
|
|
||||||
|
|
||||||
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
|
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
|
||||||
"""Saves the best model by validation ROUGE2 score."""
|
"""Saves the best model by validation ROUGE2 score."""
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from utils import (
|
|||||||
lmap,
|
lmap,
|
||||||
pickle_save,
|
pickle_save,
|
||||||
save_git_info,
|
save_git_info,
|
||||||
save_json,
|
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -189,7 +188,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
losses.update(generative_metrics)
|
losses.update(generative_metrics)
|
||||||
all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||||
all_metrics["step_count"] = self.step_count
|
all_metrics["step_count"] = self.step_count
|
||||||
self.save_metrics(all_metrics, prefix) # writes to self.metrics_save_path
|
self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
|
||||||
preds = flatten_list([x["preds"] for x in outputs])
|
preds = flatten_list([x["preds"] for x in outputs])
|
||||||
return {
|
return {
|
||||||
"log": all_metrics,
|
"log": all_metrics,
|
||||||
@@ -198,10 +197,6 @@ class SummarizationModule(BaseTransformer):
|
|||||||
f"{prefix}_{self.val_metric}": metric_tensor,
|
f"{prefix}_{self.val_metric}": metric_tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
|
||||||
self.metrics[type_path].append(latest_metrics)
|
|
||||||
save_json(self.metrics, self.metrics_save_path)
|
|
||||||
|
|
||||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||||
return calculate_rouge(preds, target)
|
return calculate_rouge(preds, target)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user