From 78387cc63eea993d931079066b463edb9b95840f Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 22 Sep 2020 18:27:28 -0400 Subject: [PATCH] [s2s] only save metrics.json from rank zero (#7331) --- examples/seq2seq/callbacks.py | 9 +++++++++ examples/seq2seq/finetune.py | 7 +------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/seq2seq/callbacks.py b/examples/seq2seq/callbacks.py index 45785e681c..c6cd2014de 100644 --- a/examples/seq2seq/callbacks.py +++ b/examples/seq2seq/callbacks.py @@ -8,6 +8,8 @@ import torch from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.utilities import rank_zero_only +from utils import save_json + def count_trainable_parameters(model): model_parameters = filter(lambda p: p.requires_grad, model.parameters()) @@ -72,8 +74,15 @@ class Seq2SeqLoggingCallback(pl.Callback): @rank_zero_only def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + save_json(pl_module.metrics, pl_module.metrics_save_path) return self._write_logs(trainer, pl_module, "test") + @rank_zero_only + def on_validation_end(self, trainer: pl.Trainer, pl_module): + save_json(pl_module.metrics, pl_module.metrics_save_path) + # Uncommenting this will save val generations + # return self._write_logs(trainer, pl_module, "valid") + def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False): """Saves the best model by validation ROUGE2 score.""" diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 0da637f13b..191b42928f 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -30,7 +30,6 @@ from utils import ( lmap, pickle_save, save_git_info, - save_json, use_task_specific_params, ) @@ -189,7 +188,7 @@ class SummarizationModule(BaseTransformer): losses.update(generative_metrics) all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} all_metrics["step_count"] = self.step_count - self.save_metrics(all_metrics, prefix) # writes to self.metrics_save_path + self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path preds = flatten_list([x["preds"] for x in outputs]) return { "log": all_metrics, @@ -198,10 +197,6 @@ class SummarizationModule(BaseTransformer): f"{prefix}_{self.val_metric}": metric_tensor, } - def save_metrics(self, latest_metrics, type_path) -> None: - self.metrics[type_path].append(latest_metrics) - save_json(self.metrics, self.metrics_save_path) - def calc_generative_metrics(self, preds, target) -> Dict: return calculate_rouge(preds, target)