From b9772897ec9f54c1a83263b059bfd37acda936d5 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 31 Aug 2020 16:16:10 -0400 Subject: [PATCH] [s2s] command line args for faster val steps (#6833) --- examples/seq2seq/distillation.py | 2 +- examples/seq2seq/finetune.py | 9 +++++++-- examples/seq2seq/test_seq2seq_examples.py | 2 ++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 262fae182f..7dabb2b084 100644 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -262,7 +262,7 @@ class BartTranslationDistiller(BartSummarizationDistiller): mode = "translation" metric_names = ["bleu"] - val_metric = "bleu" + default_val_metric = "bleu" def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 90591e1b0c..cf47e32a4f 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -63,7 +63,7 @@ class SummarizationModule(BaseTransformer): mode = "summarization" loss_names = ["loss"] metric_names = ROUGE_KEYS - val_metric = "rouge2" + default_val_metric = "rouge2" def __init__(self, hparams, **kwargs): super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs) @@ -110,6 +110,9 @@ class SummarizationModule(BaseTransformer): self.dataset_class = ( Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset ) + self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams + assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. 
Need an integer >= 1" + self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric def freeze_embeds(self): """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" @@ -301,6 +304,8 @@ class SummarizationModule(BaseTransformer): parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) parser.add_argument("--src_lang", type=str, default="", required=False) parser.add_argument("--tgt_lang", type=str, default="", required=False) + parser.add_argument("--eval_beams", type=int, default=None, required=False) + parser.add_argument("--val_metric", type=str, default=None, required=False) parser.add_argument( "--early_stopping_patience", type=int, @@ -315,7 +320,7 @@ class TranslationModule(SummarizationModule): mode = "translation" loss_names = ["loss"] metric_names = ["bleu"] - val_metric = "bleu" + default_val_metric = "bleu" def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index e7c795b7c5..410c3ee0a4 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -31,6 +31,8 @@ logger = logging.getLogger() CUDA_AVAILABLE = torch.cuda.is_available() CHEAP_ARGS = { "label_smoothing": 0.2, + "eval_beams": 1, + "val_metric": None, "adafactor": True, "early_stopping_patience": 2, "logger_name": "default",