From 8f07f5c44bf33f10b0075ce770b19de96ab389c0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 26 Nov 2020 20:12:01 +0100 Subject: [PATCH] Revert "finetune.py: specifying generation min_length (#8478)" (#8805) This reverts commit 5aa361f3e56de0f65720f291bb3975bfc98f2837. --- examples/seq2seq/finetune.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 364c5de6e9..156b4695a6 100755 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -113,10 +113,6 @@ class SummarizationModule(BaseTransformer): self.eval_max_length = self.hparams.eval_max_gen_length else: self.eval_max_length = self.model.config.max_length - if self.hparams.eval_min_gen_length is not None: - self.eval_min_length = self.hparams.eval_min_gen_length - else: - self.eval_min_length = self.model.config.min_length self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]: @@ -223,7 +219,6 @@ class SummarizationModule(BaseTransformer): decoder_start_token_id=self.decoder_start_token_id, num_beams=self.eval_beams, max_length=self.eval_max_length, - min_length=self.eval_min_length, ) gen_time = (time.time() - t0) / batch["input_ids"].shape[0] preds: List[str] = self.ids_to_clean_text(generated_ids) @@ -351,7 +346,6 @@ class SummarizationModule(BaseTransformer): "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None] ) parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens") - parser.add_argument("--eval_min_gen_length", type=int, default=None, help="never generate shorter than n tokens") parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save") parser.add_argument( "--early_stopping_patience",