Remove dependency on examples/seq2seq from rag (#7395)
Co-authored-by: Your Name <you@example.com>
This commit is contained in:
@@ -1,7 +1,20 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
import numpy as np
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
|
from pytorch_lightning.utilities import rank_zero_only
|
||||||
|
|
||||||
|
from utils import save_json
|
||||||
|
|
||||||
|
|
||||||
|
def count_trainable_parameters(model):
    """Return the total number of elements held by ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.

    Args:
        model: object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: number of trainable scalar elements; ``0`` when there are none.
    """
    # ``numel()`` is an exact Python int even for 0-dim parameters, whereas
    # ``np.prod(p.size())`` evaluates to the float 1.0 for an empty shape.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -28,3 +41,76 @@ def get_checkpoint_callback(output_dir, metric):
|
|||||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||||
)
|
)
|
||||||
return checkpoint_callback
|
return checkpoint_callback
|
||||||
|
|
||||||
|
|
||||||
|
def get_early_stopping_callback(metric, patience):
    """Build an ``EarlyStopping`` callback that monitors ``val_<metric>``.

    The monitored value is minimised when ``metric`` contains the substring
    "loss" and maximised otherwise. ``patience`` is forwarded unchanged.
    """
    if "loss" in metric:
        direction = "min"
    else:
        direction = "max"
    monitored = f"val_{metric}"  # does this need avg?
    return EarlyStopping(monitor=monitored, mode=direction, patience=patience, verbose=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqLoggingCallback(pl.Callback):
    """Lightning callback that records learning rates, parameter counts and metrics.

    Learning rates are logged after every batch, parameter counts once when
    training starts, and ``pl_module.metrics`` is written to
    ``pl_module.metrics_save_path`` after each validation and test run.
    """

    def on_batch_end(self, trainer, pl_module):
        """Log the learning rate of every param group of the first optimizer."""
        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        """Send callback metrics to the logger and append them to a results file.

        Args:
            trainer: trainer whose ``callback_metrics`` and ``global_step`` are read.
            pl_module: module whose ``hparams.output_dir`` is the base directory.
            type_path: split name; "test" writes to fixed file names, anything
                else writes to per-step files in ``<type_path>_results`` /
                ``<type_path>_generations`` sub-directories.
            save_generations: when true and the metrics contain "preds", the
                predictions are written one per line to the generations file.
        """
        # Bookkeeping entries that are neither logged nor written as scalars.
        skipped_keys = ("log", "progress_bar", "preds")

        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics
        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in skipped_keys})
        # Log results
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
            results_file.parent.mkdir(exist_ok=True)
            generations_file.parent.mkdir(exist_ok=True)
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in skipped_keys:
                    continue
                val = metrics[key]
                if isinstance(val, torch.Tensor):
                    val = val.item()
                msg = f"{key}: {val:.6f}\n"
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            # Context manager so the handle is flushed and closed deterministically.
            with generations_file.open("w+") as gen_writer:
                gen_writer.write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """Log total and trainable parameter counts once training begins."""
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            # No nested ``.model`` attribute: query the wrapped model directly.
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Save the module's metrics as JSON, then write the test logs."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        return self._write_logs(trainer, pl_module, "test")

    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        """Save the module's metrics as JSON after each validation run."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        # Uncommenting this will save val generations
        # return self._write_logs(trainer, pl_module, "valid")
|
||||||
|
|||||||
@@ -34,22 +34,23 @@ from transformers import logging as transformers_logging
|
|||||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||||
|
|
||||||
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip
|
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip
|
||||||
from examples.rag.callbacks import get_checkpoint_callback # noqa: E402 # isort:skip
|
from examples.rag.callbacks import ( # noqa: E402 # isort:skip
|
||||||
|
get_checkpoint_callback,
|
||||||
|
get_early_stopping_callback,
|
||||||
|
Seq2SeqLoggingCallback,
|
||||||
|
)
|
||||||
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||||
from examples.rag.utils import ( # noqa: E402 # isort:skip
|
from examples.rag.utils import ( # noqa: E402 # isort:skip
|
||||||
Seq2SeqDataset,
|
|
||||||
calculate_exact_match,
|
calculate_exact_match,
|
||||||
is_rag_model,
|
|
||||||
set_extra_model_params,
|
|
||||||
)
|
|
||||||
from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback # noqa: E402 # isort:skip
|
|
||||||
from examples.seq2seq.utils import ( # noqa: E402 # isort:skip
|
|
||||||
flatten_list,
|
flatten_list,
|
||||||
get_git_info,
|
get_git_info,
|
||||||
|
is_rag_model,
|
||||||
lmap,
|
lmap,
|
||||||
pickle_save,
|
pickle_save,
|
||||||
save_git_info,
|
save_git_info,
|
||||||
save_json,
|
save_json,
|
||||||
|
set_extra_model_params,
|
||||||
|
Seq2SeqDataset,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -303,11 +304,6 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
|
|
||||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||||
dataset = self.get_dataset(type_path)
|
dataset = self.get_dataset(type_path)
|
||||||
sampler = None
|
|
||||||
if self.hparams.sortish_sampler and type_path == "train":
|
|
||||||
assert self.hparams.gpus <= 1 # TODO: assert earlier
|
|
||||||
sampler = dataset.make_sortish_sampler(batch_size)
|
|
||||||
shuffle = False
|
|
||||||
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
@@ -315,7 +311,6 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
collate_fn=dataset.collate_fn,
|
collate_fn=dataset.collate_fn,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
num_workers=self.num_workers,
|
num_workers=self.num_workers,
|
||||||
sampler=sampler,
|
|
||||||
)
|
)
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
@@ -379,7 +374,6 @@ class GenerativeQAModule(BaseTransformer):
|
|||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated, sequences shorter will be padded.",
|
"than this will be truncated, sequences shorter will be padded.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
|
||||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||||
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||||
|
|||||||
@@ -1,15 +1,20 @@
|
|||||||
|
import itertools
|
||||||
|
import json
|
||||||
import linecache
|
import linecache
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
import re
|
import re
|
||||||
|
import socket
|
||||||
import string
|
import string
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Callable, Dict, Iterable, List
|
||||||
|
|
||||||
|
import git
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from examples.seq2seq.utils import SortishSampler, trim_batch
|
|
||||||
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
|
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -27,6 +32,19 @@ def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=Tru
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Drop every column of ``input_ids`` that contains nothing but ``pad_token_id``.

    Returns the trimmed ``input_ids`` alone when ``attention_mask`` is None;
    otherwise a ``(input_ids, attention_mask)`` tuple with the same columns
    removed from both tensors.
    """
    # True for each column holding at least one non-pad token.
    non_pad_columns = input_ids.ne(pad_token_id).any(dim=0)
    trimmed_ids = input_ids[:, non_pad_columns]
    if attention_mask is not None:
        return (trimmed_ids, attention_mask[:, non_pad_columns])
    return trimmed_ids
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqDataset(Dataset):
|
class Seq2SeqDataset(Dataset):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -114,13 +132,52 @@ class Seq2SeqDataset(Dataset):
|
|||||||
}
|
}
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def make_sortish_sampler(self, batch_size):
|
|
||||||
return SortishSampler(self.src_lens, batch_size)
|
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_list(summary_ids: List[List]):
    """Flatten one level of nesting, concatenating the inner lists in order.

    Args:
        summary_ids: list of lists to concatenate.

    Returns:
        A new flat list with the inner elements; empty when the input is empty.
    """
    # ``list(...)`` consumes the chained iterator directly; no identity
    # comprehension is needed.
    return list(itertools.chain.from_iterable(summary_ids))
|
||||||
|
|
||||||
|
|
||||||
|
def save_git_info(folder_path: str) -> None:
    """Write the result of ``get_git_info()`` to ``<folder_path>/git_log.json``."""
    git_details = get_git_info()
    destination = os.path.join(folder_path, "git_log.json")
    save_json(git_details, destination)
|
||||||
|
|
||||||
|
|
||||||
|
def save_json(content, path, indent=4, **json_dump_kwargs):
    """Serialize ``content`` as JSON into the file at ``path``, overwriting it.

    Args:
        content: JSON-serializable object.
        path: destination file path.
        indent: indentation passed to ``json.dump``.
        **json_dump_kwargs: extra keyword arguments forwarded to ``json.dump``.
    """
    # Explicit UTF-8 keeps the output independent of the platform locale,
    # which matters when callers pass ``ensure_ascii=False``.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(content, f, indent=indent, **json_dump_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(path):
    """Read the file at ``path`` and return its parsed JSON content."""
    # JSON text is UTF-8; decode explicitly rather than with the platform
    # locale so non-ASCII content loads the same everywhere.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_info():
    """Collect identifying information about the enclosing git repository.

    The repository is located by searching parent directories of the current
    working directory.

    Returns:
        dict with keys ``repo_id``, ``repo_sha``, ``repo_branch`` and
        ``hostname``. ``repo_branch`` is ``None`` when HEAD is detached.
    """
    repo = git.Repo(search_parent_directories=True)
    try:
        branch = str(repo.active_branch)
    except TypeError:
        # GitPython raises TypeError for a detached HEAD (e.g. CI checkouts
        # of a specific commit); record "no branch" instead of crashing.
        branch = None
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": branch,
        "hostname": str(socket.gethostname()),
    }
    return repo_infos
|
||||||
|
|
||||||
|
|
||||||
|
def lmap(f: Callable, x: Iterable) -> List:
    """Apply ``f`` to each element of ``x`` and return the results as a list."""
    mapped = map(f, x)
    return [*mapped]
|
||||||
|
|
||||||
|
|
||||||
|
def pickle_save(obj, path):
    """Pickle ``obj`` into the file at ``path``, overwriting any existing file."""
    with open(path, "wb") as sink:
        pickle.dump(obj, sink)
|
||||||
|
|
||||||
|
|
||||||
def normalize_answer(s):
|
def normalize_answer(s):
|
||||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user