[s2s] support early stopping based on loss, rather than rouge (#6927)

This commit is contained in:
Sam Shleifer
2020-09-03 17:31:35 -04:00
committed by GitHub
parent 207ed8cb78
commit e95d262f25
3 changed files with 38 additions and 21 deletions

View File

@@ -75,21 +75,23 @@ class Seq2SeqLoggingCallback(pl.Callback):
return self._write_logs(trainer, pl_module, "test")
def get_checkpoint_callback(output_dir, metric, save_top_k=1):
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
"""Saves the best model by validation ROUGE2 score."""
if metric == "rouge2":
exp = "{val_avg_rouge2:.4f}-{step_count}"
elif metric == "bleu":
exp = "{val_avg_bleu:.4f}-{step_count}"
elif metric == "loss":
exp = "{val_avg_loss:.4f}-{step_count}"
else:
raise NotImplementedError(
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(output_dir, exp),
monitor=f"val_{metric}",
mode="max",
mode="min" if "loss" in metric else "max",
save_top_k=save_top_k,
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
)
@@ -98,8 +100,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1):
def get_early_stopping_callback(metric, patience):
return EarlyStopping(
monitor=f"val_{metric}",
mode="max",
monitor=f"val_{metric}", # does this need avg?
mode="min" if "loss" in metric else "max",
patience=patience,
verbose=True,
)