[s2s] only save metrics.json from rank zero (#7331)

This commit is contained in:
Sam Shleifer
2020-09-22 18:27:28 -04:00
committed by GitHub
parent e53138a1b9
commit 78387cc63e
2 changed files with 10 additions and 6 deletions

View File

@@ -8,6 +8,8 @@ import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from utils import save_json
def count_trainable_parameters(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
@@ -72,8 +74,15 @@ class Seq2SeqLoggingCallback(pl.Callback):
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
save_json(pl_module.metrics, pl_module.metrics_save_path)
return self._write_logs(trainer, pl_module, "test")
@rank_zero_only
def on_validation_end(self, trainer: pl.Trainer, pl_module):
save_json(pl_module.metrics, pl_module.metrics_save_path)
# Uncommenting this will save val generations
# return self._write_logs(trainer, pl_module, "valid")
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
"""Saves the best model by validation ROUGE2 score."""