From fe326bd5cf1aa4ec65286e6500070f5440420a82 Mon Sep 17 00:00:00 2001 From: Ola Piktus Date: Fri, 25 Sep 2020 17:20:49 +0100 Subject: [PATCH] Remove dependency on examples/seq2seq from rag (#7395) Co-authored-by: Your Name --- examples/rag/callbacks.py | 88 ++++++++++++++++++++++++++++++++++++++- examples/rag/finetune.py | 22 ++++------ examples/rag/utils.py | 67 ++++++++++++++++++++++++++--- 3 files changed, 157 insertions(+), 20 deletions(-) diff --git a/examples/rag/callbacks.py b/examples/rag/callbacks.py index 222db114dd..099cf2bbdf 100644 --- a/examples/rag/callbacks.py +++ b/examples/rag/callbacks.py @@ -1,7 +1,20 @@ import logging import os +from pathlib import Path -from pytorch_lightning.callbacks import ModelCheckpoint +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_only + +from utils import save_json + + +def count_trainable_parameters(model): + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + params = sum([np.prod(p.size()) for p in model_parameters]) + return params logger = logging.getLogger(__name__) @@ -28,3 +41,76 @@ def get_checkpoint_callback(output_dir, metric): period=0, # maybe save a checkpoint every time val is run, not just end of epoch. ) return checkpoint_callback + + +def get_early_stopping_callback(metric, patience): + return EarlyStopping( + monitor=f"val_{metric}", # does this need avg? + mode="min" if "loss" in metric else "max", + patience=patience, + verbose=True, + ) + + +class Seq2SeqLoggingCallback(pl.Callback): + def on_batch_end(self, trainer, pl_module): + lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} + pl_module.logger.log_metrics(lrs) + + @rank_zero_only + def _write_logs( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True + ) -> None: + logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") + metrics = trainer.callback_metrics + trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) + # Log results + od = Path(pl_module.hparams.output_dir) + if type_path == "test": + results_file = od / "test_results.txt" + generations_file = od / "test_generations.txt" + else: + # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json + # If people want this it will be easy enough to add back. + results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" + generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt" + results_file.parent.mkdir(exist_ok=True) + generations_file.parent.mkdir(exist_ok=True) + with open(results_file, "a+") as writer: + for key in sorted(metrics): + if key in ["log", "progress_bar", "preds"]: + continue + val = metrics[key] + if isinstance(val, torch.Tensor): + val = val.item() + msg = f"{key}: {val:.6f}\n" + writer.write(msg) + + if not save_generations: + return + + if "preds" in metrics: + content = "\n".join(metrics["preds"]) + generations_file.open("w+").write(content) + + @rank_zero_only + def on_train_start(self, trainer, pl_module): + try: + npars = pl_module.model.model.num_parameters() + except AttributeError: + npars = pl_module.model.num_parameters() + + n_trainable_pars = count_trainable_parameters(pl_module) + # mp stands for million parameters + trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) + + @rank_zero_only + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + save_json(pl_module.metrics, pl_module.metrics_save_path) + return self._write_logs(trainer, pl_module, "test") + + @rank_zero_only + def on_validation_end(self, trainer: pl.Trainer, pl_module): + save_json(pl_module.metrics, pl_module.metrics_save_path) + # Uncommenting this will save val generations + # return self._write_logs(trainer, pl_module, "valid") diff --git a/examples/rag/finetune.py b/examples/rag/finetune.py index 4b39724875..c76045fc3d 100644 --- a/examples/rag/finetune.py +++ b/examples/rag/finetune.py @@ -34,22 +34,23 @@ from transformers import logging as transformers_logging sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip -from examples.rag.callbacks import get_checkpoint_callback # noqa: E402 # isort:skip +from examples.rag.callbacks import ( # noqa: E402 # isort:skip + get_checkpoint_callback, + get_early_stopping_callback, + Seq2SeqLoggingCallback, +) from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip from examples.rag.utils import ( # noqa: E402 # isort:skip - Seq2SeqDataset, calculate_exact_match, - is_rag_model, - set_extra_model_params, -) -from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback # noqa: E402 # isort:skip -from examples.seq2seq.utils import ( # noqa: E402 # isort:skip flatten_list, get_git_info, + is_rag_model, lmap, pickle_save, save_git_info, save_json, + set_extra_model_params, + Seq2SeqDataset, ) logging.basicConfig(level=logging.INFO) @@ -303,11 +304,6 @@ class GenerativeQAModule(BaseTransformer): def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader: dataset = self.get_dataset(type_path) - sampler = None - if self.hparams.sortish_sampler and type_path == "train": - assert self.hparams.gpus <= 1 # TODO: assert earlier - sampler = dataset.make_sortish_sampler(batch_size) - shuffle = False dataloader = DataLoader( dataset, @@ -315,7 +311,6 @@ class GenerativeQAModule(BaseTransformer): collate_fn=dataset.collate_fn, shuffle=shuffle, num_workers=self.num_workers, - sampler=sampler, ) return dataloader @@ -379,7 +374,6 @@ class GenerativeQAModule(BaseTransformer): help="The maximum total input sequence length after tokenization. Sequences longer " "than this will be truncated, sequences shorter will be padded.", ) - parser.add_argument("--sortish_sampler", action="store_true", default=False) parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default") parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.") parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.") diff --git a/examples/rag/utils.py b/examples/rag/utils.py index e17fddb73f..7bf5d7e35e 100644 --- a/examples/rag/utils.py +++ b/examples/rag/utils.py @@ -1,15 +1,20 @@ +import itertools +import json import linecache +import os +import pickle import re +import socket import string from collections import Counter from logging import getLogger from pathlib import Path -from typing import Dict, List +from typing import Callable, Dict, Iterable, List +import git import torch from torch.utils.data import Dataset -from examples.seq2seq.utils import SortishSampler, trim_batch from transformers import BartTokenizer, RagTokenizer, T5Tokenizer @@ -27,6 +32,19 @@ def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=Tru ) +def trim_batch( + input_ids, + pad_token_id, + attention_mask=None, +): + """Remove columns that are populated exclusively by pad_token_id""" + keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) + if attention_mask is None: + return input_ids[:, keep_column_mask] + else: + return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) + + class Seq2SeqDataset(Dataset): def __init__( self, @@ -114,13 +132,52 @@ class Seq2SeqDataset(Dataset): } return batch - def make_sortish_sampler(self, batch_size): - return SortishSampler(self.src_lens, batch_size) - logger = getLogger(__name__) +def flatten_list(summary_ids: List[List]): + return [x for x in itertools.chain.from_iterable(summary_ids)] + + +def save_git_info(folder_path: str) -> None: + """Save git information to output_dir/git_log.json""" + repo_infos = get_git_info() + save_json(repo_infos, os.path.join(folder_path, "git_log.json")) + + +def save_json(content, path, indent=4, **json_dump_kwargs): + with open(path, "w") as f: + json.dump(content, f, indent=indent, **json_dump_kwargs) + + +def load_json(path): + with open(path) as f: + return json.load(f) + + +def get_git_info(): + repo = git.Repo(search_parent_directories=True) + repo_infos = { + "repo_id": str(repo), + "repo_sha": str(repo.head.object.hexsha), + "repo_branch": str(repo.active_branch), + "hostname": str(socket.gethostname()), + } + return repo_infos + + +def lmap(f: Callable, x: Iterable) -> List: + """list(map(f, x))""" + return list(map(f, x)) + + +def pickle_save(obj, path): + """pickle.dump(obj, path)""" + with open(path, "wb") as f: + return pickle.dump(obj, f) + + def normalize_answer(s): """Lower text and remove punctuation, articles and extra whitespace."""