[s2s] --eval_max_generate_length (#7018)
This commit is contained in:
@@ -11,7 +11,6 @@ from typing import Dict, List, Tuple
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
@@ -94,6 +93,9 @@ class SummarizationModule(BaseTransformer):
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
if self.hparams.sortish_sampler and self.hparams.gpus > 1:
|
||||
self.hparams.sortish_sampler = False
|
||||
warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
@@ -114,6 +116,10 @@ class SummarizationModule(BaseTransformer):
|
||||
)
|
||||
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
|
||||
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
|
||||
if self.hparams.eval_max_gen_length is not None:
|
||||
self.eval_max_length = self.hparams.eval_max_gen_length
|
||||
else:
|
||||
self.eval_max_length = self.model.config.max_length
|
||||
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
||||
|
||||
def freeze_embeds(self):
|
||||
@@ -209,12 +215,15 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
t0 = time.time()
|
||||
|
||||
# parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
|
||||
generated_ids = self.model.generate(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
use_cache=True,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
num_beams=self.eval_beams,
|
||||
max_length=self.eval_max_length,
|
||||
)
|
||||
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||
@@ -248,7 +257,7 @@ class SummarizationModule(BaseTransformer):
|
||||
dataset = self.get_dataset(type_path)
|
||||
sampler = None
|
||||
if self.hparams.sortish_sampler and type_path == "train":
|
||||
assert self.hparams.gpus <= 1 # TODO: assert earlier
|
||||
assert self.hparams.gpus <= 1 # this should never break because of the assertion in __init__
|
||||
sampler = dataset.make_sortish_sampler(batch_size)
|
||||
shuffle = False
|
||||
|
||||
@@ -321,6 +330,7 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument(
|
||||
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
|
||||
)
|
||||
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
|
||||
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
|
||||
parser.add_argument(
|
||||
"--early_stopping_patience",
|
||||
@@ -356,8 +366,6 @@ def main(args, model=None) -> SummarizationModule:
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
else:
|
||||
model: SummarizationModule = TranslationModule(args)
|
||||
if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
|
||||
warnings.warn("FP16 only seems to work with torch 1.5+apex")
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger_name == "default"
|
||||
|
||||
Reference in New Issue
Block a user