From 39ed68d597c9aa21c9188498289c653acd3fce45 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 3 Sep 2020 11:11:40 -0400 Subject: [PATCH] [s2s] allow task_specific_params=summarization_xsum (#6923) --- examples/seq2seq/callbacks.py | 4 ++-- examples/seq2seq/finetune.py | 5 +++-- examples/seq2seq/test_seq2seq_examples.py | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/seq2seq/callbacks.py b/examples/seq2seq/callbacks.py index 39a9cbc9f1..5942001b3d 100644 --- a/examples/seq2seq/callbacks.py +++ b/examples/seq2seq/callbacks.py @@ -75,7 +75,7 @@ class Seq2SeqLoggingCallback(pl.Callback): return self._write_logs(trainer, pl_module, "test") -def get_checkpoint_callback(output_dir, metric): +def get_checkpoint_callback(output_dir, metric, save_top_k=1): """Saves the best model by validation ROUGE2 score.""" if metric == "rouge2": exp = "{val_avg_rouge2:.4f}-{step_count}" @@ -90,7 +90,7 @@ def get_checkpoint_callback(output_dir, metric): filepath=os.path.join(output_dir, exp), monitor=f"val_{metric}", mode="max", - save_top_k=1, + save_top_k=save_top_k, period=0, # maybe save a checkpoint every time val is run, not just end of epoch. ) return checkpoint_callback diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index cf47e32a4f..16e407fa1e 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -306,6 +306,7 @@ class SummarizationModule(BaseTransformer): parser.add_argument("--tgt_lang", type=str, default="", required=False) parser.add_argument("--eval_beams", type=int, default=None, required=False) parser.add_argument("--val_metric", type=str, default=None, required=False) + parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save") parser.add_argument( "--early_stopping_patience", type=int, @@ -336,7 +337,7 @@ def main(args, model=None) -> SummarizationModule: if len(os.listdir(args.output_dir)) > 3 and args.do_train: raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) if model is None: - if args.task == "summarization": + if "summarization" in args.task: model: SummarizationModule = SummarizationModule(args) else: model: SummarizationModule = TranslationModule(args) @@ -368,7 +369,7 @@ def main(args, model=None) -> SummarizationModule: model, args, logging_callback=Seq2SeqLoggingCallback(), - checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric), + checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, args.save_top_k), early_stopping_callback=es_callback, logger=logger, # TODO: early stopping callback seems messed up diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 89fafa9346..2ecc7b8883 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -34,6 +34,7 @@ CHEAP_ARGS = { "label_smoothing": 0.2, "eval_beams": 1, "val_metric": None, + "save_top_k": 1, "adafactor": True, "early_stopping_patience": 2, "logger_name": "default",