[s2s] allow task_specific_params=summarization_xsum (#6923)

This commit is contained in:
Sam Shleifer
2020-09-03 11:11:40 -04:00
committed by GitHub
parent 5a318f075a
commit 39ed68d597
3 changed files with 6 additions and 4 deletions

View File

@@ -75,7 +75,7 @@ class Seq2SeqLoggingCallback(pl.Callback):
return self._write_logs(trainer, pl_module, "test")
def get_checkpoint_callback(output_dir, metric):
def get_checkpoint_callback(output_dir, metric, save_top_k=1):
"""Saves the best model by validation ROUGE2 score."""
if metric == "rouge2":
exp = "{val_avg_rouge2:.4f}-{step_count}"
@@ -90,7 +90,7 @@ def get_checkpoint_callback(output_dir, metric):
filepath=os.path.join(output_dir, exp),
monitor=f"val_{metric}",
mode="max",
save_top_k=1,
save_top_k=save_top_k,
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
)
return checkpoint_callback