|
|
|
@@ -11,7 +11,6 @@ from typing import Dict, List, Tuple
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from packaging import version
|
|
|
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
|
|
|
|
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
|
|
|
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
|
|
|
@@ -94,6 +93,9 @@ class SummarizationModule(BaseTransformer):
|
|
|
|
"val": self.hparams.val_max_target_length,
|
|
|
|
"val": self.hparams.val_max_target_length,
|
|
|
|
"test": self.hparams.test_max_target_length,
|
|
|
|
"test": self.hparams.test_max_target_length,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if self.hparams.sortish_sampler and self.hparams.gpus > 1:
|
|
|
|
|
|
|
|
self.hparams.sortish_sampler = False
|
|
|
|
|
|
|
|
warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
|
|
|
|
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
|
|
|
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
|
|
|
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
|
|
|
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
|
|
|
|
|
|
|
|
|
|
|
@@ -114,6 +116,10 @@ class SummarizationModule(BaseTransformer):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
|
|
|
|
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
|
|
|
|
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
|
|
|
|
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
|
|
|
|
|
|
|
|
if self.hparams.eval_max_gen_length is not None:
|
|
|
|
|
|
|
|
self.eval_max_length = self.hparams.eval_max_gen_length
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.eval_max_length = self.model.config.max_length
|
|
|
|
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
|
|
|
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
|
|
|
|
|
|
|
|
|
|
|
def freeze_embeds(self):
|
|
|
|
def freeze_embeds(self):
|
|
|
|
@@ -209,12 +215,15 @@ class SummarizationModule(BaseTransformer):
|
|
|
|
|
|
|
|
|
|
|
|
def _generative_step(self, batch: dict) -> dict:
|
|
|
|
def _generative_step(self, batch: dict) -> dict:
|
|
|
|
t0 = time.time()
|
|
|
|
t0 = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
|
|
|
|
generated_ids = self.model.generate(
|
|
|
|
generated_ids = self.model.generate(
|
|
|
|
batch["input_ids"],
|
|
|
|
batch["input_ids"],
|
|
|
|
attention_mask=batch["attention_mask"],
|
|
|
|
attention_mask=batch["attention_mask"],
|
|
|
|
use_cache=True,
|
|
|
|
use_cache=True,
|
|
|
|
decoder_start_token_id=self.decoder_start_token_id,
|
|
|
|
decoder_start_token_id=self.decoder_start_token_id,
|
|
|
|
num_beams=self.eval_beams,
|
|
|
|
num_beams=self.eval_beams,
|
|
|
|
|
|
|
|
max_length=self.eval_max_length,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
|
|
|
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
|
|
|
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
|
|
|
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
|
|
|
@@ -248,7 +257,7 @@ class SummarizationModule(BaseTransformer):
|
|
|
|
dataset = self.get_dataset(type_path)
|
|
|
|
dataset = self.get_dataset(type_path)
|
|
|
|
sampler = None
|
|
|
|
sampler = None
|
|
|
|
if self.hparams.sortish_sampler and type_path == "train":
|
|
|
|
if self.hparams.sortish_sampler and type_path == "train":
|
|
|
|
assert self.hparams.gpus <= 1 # TODO: assert earlier
|
|
|
|
assert self.hparams.gpus <= 1 # this should never break because of the assertion in __init__
|
|
|
|
sampler = dataset.make_sortish_sampler(batch_size)
|
|
|
|
sampler = dataset.make_sortish_sampler(batch_size)
|
|
|
|
shuffle = False
|
|
|
|
shuffle = False
|
|
|
|
|
|
|
|
|
|
|
|
@@ -321,6 +330,7 @@ class SummarizationModule(BaseTransformer):
|
|
|
|
parser.add_argument(
|
|
|
|
parser.add_argument(
|
|
|
|
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
|
|
|
|
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
|
|
|
|
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
|
|
|
|
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
|
|
|
|
parser.add_argument(
|
|
|
|
parser.add_argument(
|
|
|
|
"--early_stopping_patience",
|
|
|
|
"--early_stopping_patience",
|
|
|
|
@@ -356,8 +366,6 @@ def main(args, model=None) -> SummarizationModule:
|
|
|
|
model: SummarizationModule = SummarizationModule(args)
|
|
|
|
model: SummarizationModule = SummarizationModule(args)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
model: SummarizationModule = TranslationModule(args)
|
|
|
|
model: SummarizationModule = TranslationModule(args)
|
|
|
|
if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
|
|
|
|
|
|
|
|
warnings.warn("FP16 only seems to work with torch 1.5+apex")
|
|
|
|
|
|
|
|
dataset = Path(args.data_dir).name
|
|
|
|
dataset = Path(args.data_dir).name
|
|
|
|
if (
|
|
|
|
if (
|
|
|
|
args.logger_name == "default"
|
|
|
|
args.logger_name == "default"
|
|
|
|
|