From 043f9f51f943ea95e5052b7923769ba28ac1b1bd Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Wed, 17 Jun 2020 13:51:34 -0400 Subject: [PATCH] [examples] SummarizationModule improvements (#4951) --- examples/lightning_base.py | 148 ++++-- examples/requirements.txt | 3 +- examples/summarization/README.md | 75 ++- examples/summarization/callbacks.py | 85 ++++ examples/summarization/distillation.py | 448 ++++++++++++++++++ examples/summarization/evaluate_cnn.py | 100 ---- examples/summarization/finetune.py | 295 +++++++++--- examples/summarization/finetune.sh | 23 + examples/summarization/finetune_bart.sh | 18 - .../summarization/initialization_utils.py | 20 + examples/summarization/run_distiller.sh | 12 + examples/summarization/run_eval.py | 78 +++ .../test_summarization_examples.py | 277 ++++++++--- examples/summarization/utils.py | 229 ++++++++- src/transformers/modeling_bart.py | 2 +- 15 files changed, 1465 insertions(+), 348 deletions(-) create mode 100644 examples/summarization/callbacks.py create mode 100644 examples/summarization/distillation.py delete mode 100644 examples/summarization/evaluate_cnn.py create mode 100755 examples/summarization/finetune.sh delete mode 100644 examples/summarization/finetune_bart.sh create mode 100644 examples/summarization/initialization_utils.py create mode 100755 examples/summarization/run_distiller.sh create mode 100644 examples/summarization/run_eval.py diff --git a/examples/lightning_base.py b/examples/lightning_base.py index 480b69f268..39604efae3 100644 --- a/examples/lightning_base.py +++ b/examples/lightning_base.py @@ -2,6 +2,8 @@ import argparse import logging import os import random +from pathlib import Path +from typing import Any, Dict import numpy as np import pytorch_lightning as pl @@ -13,10 +15,13 @@ from transformers import ( AutoModel, AutoModelForPreTraining, AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelWithLMHead, 
AutoTokenizer, + PretrainedConfig, + PreTrainedTokenizer, get_linear_schedule_with_warmup, ) @@ -31,6 +36,8 @@ MODEL_MODES = { "pretraining": AutoModelForPreTraining, "token-classification": AutoModelForTokenClassification, "language-modeling": AutoModelWithLMHead, + "summarization": AutoModelForSeq2SeqLM, + "translation": AutoModelForSeq2SeqLM, } @@ -38,33 +45,59 @@ def set_seed(args: argparse.Namespace): random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) - if args.n_gpu > 0: + if args.gpus > 0: torch.cuda.manual_seed_all(args.seed) class BaseTransformer(pl.LightningModule): - def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs): + def __init__( + self, + hparams: argparse.Namespace, + num_labels=None, + mode="base", + config=None, + tokenizer=None, + model=None, + **config_kwargs + ): "Initialize a model." super().__init__() self.hparams = hparams + self.step_count = 0 + self.tfmr_ckpts = {} + self.output_dir = Path(self.hparams.output_dir) cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None - self.config = AutoConfig.from_pretrained( - self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, - **({"num_labels": num_labels} if num_labels is not None else {}), - cache_dir=cache_dir, - **config_kwargs, - ) - self.tokenizer = AutoTokenizer.from_pretrained( - self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, - cache_dir=cache_dir, - ) - self.model = MODEL_MODES[mode].from_pretrained( - self.hparams.model_name_or_path, - from_tf=bool(".ckpt" in self.hparams.model_name_or_path), - config=self.config, - cache_dir=cache_dir, - ) + if config is None: + self.config = AutoConfig.from_pretrained( + self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, + **({"num_labels": num_labels} if num_labels is not None else {}), + cache_dir=cache_dir, + **config_kwargs, + ) + else: + 
self.config: PretrainedConfig = config + if tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained( + self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, + cache_dir=cache_dir, + ) + else: + self.tokenizer: PreTrainedTokenizer = tokenizer + if model is None: + self.model_type = MODEL_MODES[mode] + self.model = self.model_type.from_pretrained( + self.hparams.model_name_or_path, + from_tf=bool(".ckpt" in self.hparams.model_name_or_path), + config=self.config, + cache_dir=cache_dir, + ) + else: + self.model_type = None + self.model = model + + def load_hf_checkpoint(self, *args, **kwargs): + self.model = self.model_type.from_pretrained(*args, **kwargs) def is_logger(self): return self.trainer.proc_rank <= 0 @@ -138,6 +171,15 @@ class BaseTransformer(pl.LightningModule): ), ) + @pl.utilities.rank_zero_only + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + save_path = self.output_dir.joinpath("best_tfmr") + save_path.mkdir(exist_ok=True) + self.model.config.save_step = self.step_count + self.model.save_pretrained(save_path) + self.tokenizer.save_pretrained(save_path) + self.tfmr_ckpts[self.step_count] = save_path + @staticmethod def add_model_specific_args(parser, root_dir): parser.add_argument( @@ -152,7 +194,7 @@ class BaseTransformer(pl.LightningModule): ) parser.add_argument( "--tokenizer_name", - default="", + default=None, type=str, help="Pretrained tokenizer name or path if not the same as model_name", ) @@ -165,7 +207,7 @@ class BaseTransformer(pl.LightningModule): parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") - parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") + 
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.") parser.add_argument( "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform." ) @@ -199,7 +241,8 @@ class LoggingCallback(pl.Callback): writer.write("{} = {}\n".format(key, str(metrics[key]))) -def add_generic_args(parser, root_dir): +def add_generic_args(parser, root_dir) -> None: + # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser) parser.add_argument( "--output_dir", default=None, @@ -221,8 +264,8 @@ def add_generic_args(parser, root_dir): help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." "See details at https://nvidia.github.io/apex/amp.html", ) - - parser.add_argument("--n_gpu", type=int, default=1) + parser.add_argument("--fast_dev_run", action="store_true") + parser.add_argument("--gpus", type=int, default=1) parser.add_argument("--n_tpu_cores", type=int, default=0) parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--do_train", action="store_true", help="Whether to run training.") @@ -235,28 +278,32 @@ def add_generic_args(parser, root_dir): ) parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--resume_from_checkpoint", type=str, default=None) + parser.add_argument("--val_check_interval", default=1.0, type=float) -def generic_train(model: BaseTransformer, args: argparse.Namespace): +def generic_train( + model: BaseTransformer, + args: argparse.Namespace, + early_stopping_callback=False, + logger=True, # can pass WandbLogger() here + extra_callbacks=[], + checkpoint_callback=None, + logging_callback=None, + **extra_train_kwargs +): # init model set_seed(args) + odir = Path(model.hparams.output_dir) + odir.mkdir(exist_ok=True) + if checkpoint_callback is None: + checkpoint_callback = pl.callbacks.ModelCheckpoint( + 
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1 + ) + if logging_callback is None: + logging_callback = LoggingCallback() - if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: - raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) - - checkpoint_callback = pl.callbacks.ModelCheckpoint( - filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5 - ) - - train_params = dict( - accumulate_grad_batches=args.gradient_accumulation_steps, - gpus=args.n_gpu, - max_epochs=args.num_train_epochs, - early_stop_callback=False, - gradient_clip_val=args.max_grad_norm, - checkpoint_callback=checkpoint_callback, - callbacks=[LoggingCallback()], - ) + train_params = {} if args.fp16: train_params["use_amp"] = args.fp16 @@ -269,12 +316,27 @@ def generic_train(model: BaseTransformer, args: argparse.Namespace): train_params["num_tpu_cores"] = args.n_tpu_cores train_params["gpus"] = 0 - if args.n_gpu > 1: + if args.gpus > 1: train_params["distributed_backend"] = "ddp" - trainer = pl.Trainer(**train_params) + trainer = pl.Trainer( + logger=logger, + accumulate_grad_batches=args.gradient_accumulation_steps, + gpus=args.gpus, + max_epochs=args.num_train_epochs, + early_stop_callback=early_stopping_callback, + gradient_clip_val=args.max_grad_norm, + checkpoint_callback=checkpoint_callback, + callbacks=[logging_callback] + extra_callbacks, + fast_dev_run=args.fast_dev_run, + val_check_interval=args.val_check_interval, + weights_summary=None, + resume_from_checkpoint=args.resume_from_checkpoint, + **train_params, + ) if args.do_train: trainer.fit(model) - + trainer.logger.log_hyperparams(args) + trainer.logger.save() return trainer diff --git a/examples/requirements.txt b/examples/requirements.txt index 474600d98d..05d716bdc0 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -5,5 +5,6 @@ psutil sacrebleu 
rouge-score tensorflow_datasets -pytorch-lightning==0.7.3 # April 10, 2020 release +pytorch-lightning==0.7.6 matplotlib +git-python==1.0.3 diff --git a/examples/summarization/README.md b/examples/summarization/README.md index ad4adc4b55..b626333cba 100644 --- a/examples/summarization/README.md +++ b/examples/summarization/README.md @@ -1,47 +1,70 @@ -### Get CNN Data -To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running: +### Data +CNN/DailyMail data ```bash +cd examples/summarization wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz tar -xzvf cnn_dm.tgz +export CNN_DIR=${PWD}/cnn_dm ``` this should make a directory called cnn_dm/ with files like `test.source`. To use your own data, copy that files format. Each article to be summarized is on its own line. +XSUM Data: +```bash +cd examples/summarization +wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz +tar -xzvf xsum.tar.gz +export XSUM_DIR=${PWD}/xsum +``` + + ### Evaluation To create summaries for each article in dataset, run: ```bash -python evaluate_cnn.py test_generations.txt --score_path rouge_scores.txt +python run_eval.py test_generations.txt --score_path rouge_scores.txt ``` -The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system. +The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system. 
+ ### Training -Run/modify `finetune_bart.sh` or `finetune_t5.sh` +Run/modify `finetune.sh` -### Stanford CoreNLP Setup +The following command should work on a 16GB GPU: +```bash +export me=`git config user.name` +./finetune.sh \ + --data_dir $XSUM_DIR \ + --train_batch_size=1 \ + --eval_batch_size=1 \ + --output_dir="$me"_xsum_results \ + --num_train_epochs 1 ``` -ptb_tokenize () { - cat $1 | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > $2 -} -sudo apt install openjdk-8-jre-headless -sudo apt-get install ant -wget http://nlp.stanford.edu/software/stanford-corenlp-full-2018-10-05.zip -unzip stanford-corenlp-full-2018-10-05.zip -cd stanford-corenlp-full-2018-10-05 -export CLASSPATH=stanford-corenlp-3.9.2.jar:stanford-corenlp-3.9.2-models.jar -``` -Then run `ptb_tokenize` on `test.target` and your generated hypotheses. -### Rouge Setup -Install `files2rouge` following the instructions at [here](https://github.com/pltrdy/files2rouge). -I also needed to run `sudo apt-get install libxml-parser-perl` +Tips: +- 1 epoch at batch size 1 for bart-large takes 24 hours, requires 13GB GPU RAM with fp16 on an NVIDIA-V100. +- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see below) +- `fp16_opt_level=O1` (the default works best). +- If you are finetuning on your own dataset, start from `bart-large-cnn` if you want long summaries and `bart-large-xsum` if you want short summaries. +(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods). +- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved. +Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`. +- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code. 
+- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter. -```python -from files2rouge import files2rouge -from files2rouge import settings -files2rouge.run(, - , - saveto='rouge_output.txt') +### XSUM Shared Task +Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration. +Here is an example command +```bash +export me=`git config user.name` +./finetune.sh \ + --data_dir $XSUM_DIR \ + --output_dir "$me"_xsum_frozen_embs \ + --logger wandb_shared \ + --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \ + --num_train_epochs 6 ``` + +Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_summarization/table?workspace=user-) diff --git a/examples/summarization/callbacks.py b/examples/summarization/callbacks.py new file mode 100644 index 0000000000..6129d5f0b9 --- /dev/null +++ b/examples/summarization/callbacks.py @@ -0,0 +1,85 @@ +import logging +import os +from pathlib import Path + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_only + + +def count_trainable_parameters(model): + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + params = sum([np.prod(p.size()) for p in model_parameters]) + return params + + +logger = logging.getLogger(__name__) + + +class Seq2SeqLoggingCallback(pl.Callback): + def _write_logs( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True + ) -> None: + logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") + if not pl_module.is_logger(): + return + metrics = trainer.callback_metrics + trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) + # Log results + od = 
Path(pl_module.hparams.output_dir) + if type_path == "test": + results_file = od / "test_results.txt" + generations_file = od / "test_generations.txt" + else: + results_file = od / f"{type_path}_results_{trainer.global_step:05d}.txt" + generations_file = od / f"{type_path}_generations_{trainer.global_step:05d}.txt" + + with open(results_file, "a+") as writer: + for key in sorted(metrics): + if key in ["log", "progress_bar", "preds"]: + continue + val = metrics[key] + if isinstance(val, torch.Tensor): + val = val.item() + msg = f"{key}: {val:.6f}\n" + writer.write(msg) + + if not save_generations: + return + + if "preds" in metrics: + content = "\n".join(metrics["preds"]) + generations_file.open("w+").write(content) + + @rank_zero_only + def on_train_start(self, trainer, pl_module): + try: + npars = pl_module.model.model.num_parameters() + except AttributeError: + npars = pl_module.model.num_parameters() + + n_trainable_pars = count_trainable_parameters(pl_module) + # mp stands for million parameters + trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) + + @rank_zero_only + def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + return self._write_logs(trainer, pl_module, "val") + + @rank_zero_only + def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + return self._write_logs(trainer, pl_module, "test") + + +def get_rouge2_checkpoint_callback(output_dir): + """Saves the best model by validation ROUGE2 score.""" + checkpoint_callback = ModelCheckpoint( + filepath=os.path.join(output_dir, "{val_avg_rouge2:.4f}-{step_count}"), + monitor="val_rouge", + mode="max", + save_top_k=1, + period=0, # maybe save a checkpoint every time val is run, not just end of epoch. 
+ ) + return checkpoint_callback diff --git a/examples/summarization/distillation.py b/examples/summarization/distillation.py new file mode 100644 index 0000000000..c9f5d5b04e --- /dev/null +++ b/examples/summarization/distillation.py @@ -0,0 +1,448 @@ +import argparse +import gc +import os +from pathlib import Path +from typing import List + +import pytorch_lightning as pl +import torch +from torch import nn +from torch.nn import functional as F + +from lightning_base import generic_train +from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration + + +try: + from .finetune import SummarizationModule + from .initialization_utils import init_student, copy_layers + from .utils import ( + use_task_specific_params, + SummarizationDataset, + pickle_load, + freeze_params, + assert_all_frozen, + any_requires_grad, + ) + from .finetune import main as ft_main +except ImportError: + from finetune import SummarizationModule + from finetune import main as ft_main + from initialization_utils import init_student, copy_layers + from utils import ( + use_task_specific_params, + SummarizationDataset, + pickle_load, + freeze_params, + assert_all_frozen, + any_requires_grad, + ) + + +class SummarizationDistiller(SummarizationModule): + loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"] + + def __init__(self, hparams): + assert Path(hparams.data_dir).exists() + + d_layers_to_copy, student, student_cfg, teacher = self.pre_init(hparams) + + super().__init__(hparams, model=student, config=student_cfg) + self.teacher = teacher + use_task_specific_params(self.teacher, "summarization") + freeze_params(self.teacher) + self.sanity_check_gradients() + self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") + self.temperature = 2.0 + self.alpha_mlm = hparams.alpha_mlm + self.alpha_ce = hparams.alpha_ce + self.alpha_hid = hparams.alpha_hid + # self.alpha_cos = hparams.alpha_cos + self.alpha_encoder_loss 
= self.hparams.alpha_encoder_loss + gc.collect() + torch.cuda.empty_cache() + + def sanity_check_gradients(self): + assert_all_frozen(self.teacher) + assert_all_frozen(self.model.model.decoder.embed_tokens) + assert_all_frozen(self.model.model.encoder.embed_tokens) + if self.different_encoder: + assert any_requires_grad(self.model.model.encoder) + else: + freeze_params(self.model.model.encoder) + del self.teacher.model.encoder + + def pre_init(self, hparams): + # Dump empty student model at a path, then call from_pretrained on it + teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval() + student_updates = { + "decoder_layers": hparams.student_decoder_layers, + "encoder_layers": hparams.student_encoder_layers, + } + d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers) + e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers) + hparams.d_layer_to_copy = d_layers_to_copy + hparams.e_layer_to_copy = e_layers_to_copy + kw = teacher.config.to_diff_dict() + kw.update(student_updates) + # Copy weights + student_cfg = BartConfig(**kw) + student = BartForConditionalGeneration(student_cfg) + student, _ = init_student(student, teacher) + self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher) + Path(hparams.output_dir).mkdir(exist_ok=True) + return d_layers_to_copy, student, student_cfg, teacher + + def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher): + if teacher.config.model_type == "t5": + return self.copy_t5_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher) + self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.encoder_layers + self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers + if self.different_decoder: + copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, 
d_layers_to_copy) + if self.different_encoder: + copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy) + + def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher): + self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.num_layers + self.different_decoder = hparams.student_decoder_layers != teacher.config.num_layers + if self.different_decoder: + copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy) + if self.different_encoder: + copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy) + + def get_dataset(self, type_path) -> SummarizationDataset: + n_obs = self.n_obs[type_path] + dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs) + return dataset + + def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor: + if mask is not None: + # mask has False at padding_idx + sel_mask = mask[:, :, None].expand_as(student_outputs).bool() + s_logits_slct = torch.masked_select(student_outputs, sel_mask) + t_logits_slct = torch.masked_select(teacher_outputs, sel_mask) + else: + t_logits_slct = teacher_outputs + s_logits_slct = student_outputs + return F.mse_loss(s_logits_slct, t_logits_slct) + + def calc_ce_loss(self, mask, s_logits, t_logits): + if mask is not None: + # mask has False at padding_idx + sel_mask = mask[:, :, None].expand_as(s_logits) + s_logits_slct = torch.masked_select( + s_logits, sel_mask + ) # (bs * seq_length * voc_size) modulo the 1s in mask + t_logits_slct = torch.masked_select( + t_logits, sel_mask + ) # (bs * seq_length * voc_size) modulo the 1s in mask + else: + t_logits_slct = t_logits + s_logits_slct = s_logits # (bs * seq_length * voc_size) modulo the 1s in mask + s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask + t_logits_slct = 
t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask + assert t_logits_slct.size() == s_logits_slct.size() + loss_ce = ( + self.ce_loss_fct( + F.log_softmax(s_logits_slct / self.temperature, dim=-1), + F.softmax(t_logits_slct / self.temperature, dim=-1), + ) + * (self.temperature) ** 2 + ) + return loss_ce, s_logits_slct, t_logits_slct + + def configure_optimizers(self): + "Prepare optimizer and schedule (linear warmup and decay)" + + model = self.model + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": self.hparams.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) + self.opt = optimizer + return [optimizer] + + @staticmethod + def add_model_specific_args(parser, root_dir): + SummarizationModule.add_model_specific_args(parser, root_dir) + parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str) + parser.add_argument("--alpha_ce", default=0.8, type=float) + parser.add_argument("--alpha_mlm", default=0.2, type=float) + # parser.add_argument("--alpha_cos", default=0.0, type=float) + parser.add_argument("--alpha_encoder_loss", default=0.0, type=float) + parser.add_argument("--alpha_hid", default=0.0, type=float, required=False) + parser.add_argument( + "--student_decoder_layers", default=12, type=int, required=False, + ) + parser.add_argument( + "--student_encoder_layers", default=12, type=int, required=False, + ) + parser.add_argument( + "--no_teacher", action="store_true", default=False, + ) + parser.add_argument( # TODO: remove + "--enc_only", action="store_true", default=False, + ) + return parser + + def _step(self, batch): + # assert is_frozen(self.teacher) + 
pad_token_id = self.tokenizer.pad_token_id + input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] + decoder_input_ids = y[:, :-1].contiguous() + labels = y[:, 1:].clone() + labels[y[:, 1:] == pad_token_id] = -100 + # noinspection PyCallingNonCallable + sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self( + input_ids, + attention_mask=src_mask, + decoder_input_ids=decoder_input_ids, + labels=labels, + output_hidden_states=True, + output_attentions=False, + ) + + def zero_tensor(): + return torch.tensor(0.0).type_as(sloss) + + loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor() + if self.different_encoder: + with torch.no_grad(): + teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder( + input_ids, attention_mask=src_mask, output_hidden_states=True + ) + if self.hparams.alpha_encoder_loss > 0: + loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask) + + hid_loss_enc = self.calc_hidden_loss( + src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy + ) + + teacher_enc_outputs = (enc_outputs,) + assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs) + + with torch.no_grad(): + tloss, tlogits, tdec_hidden, _ = self.teacher( + input_ids, + attention_mask=src_mask, + encoder_outputs=teacher_enc_outputs, + decoder_input_ids=decoder_input_ids, + lm_labels=labels, + output_hidden_states=True, + ) + dec_mask = decoder_input_ids.ne(pad_token_id) + loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits) + if self.alpha_hid > 0: + hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy) + + blended_loss = ( + self.alpha_ce * loss_ce + + self.alpha_mlm * sloss + + self.hparams.alpha_encoder_loss * loss_encoder + + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec) + ) + return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, 
hid_loss_dec + + def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches): + assert not isinstance( + hidden_states, torch.Tensor + ), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}" + assert not isinstance( + hidden_states_T, torch.Tensor + ), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}" + mask = attention_mask.to(hidden_states[0]) + valid_count = mask.sum() * hidden_states[0].size(-1) + hidden_losses = [ + (F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum() + / valid_count + for i, j in enumerate(matches) + ] + return sum(hidden_losses) + + +class T5SummarizationDistiller(SummarizationDistiller): + def pre_init(self, hparams): + teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher) + n_layer = hparams.student_decoder_layers + assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this + d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block)) + e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block)) + student_updates = {"num_layers": n_layer} + hparams.d_layer_to_copy = d_layers_to_copy + hparams.e_layer_to_copy = e_layers_to_copy + kw = teacher.config.to_diff_dict() + + kw.update(student_updates) + # Copy weights + student_cfg = T5Config(**kw) + student = T5ForConditionalGeneration(student_cfg) + student, _ = init_student(student, teacher) + self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher) + Path(hparams.output_dir).mkdir(exist_ok=True) + task_specific_params = student.config.task_specific_params + if task_specific_params is not None: + student.config.update(task_specific_params.get("summarization", {})) + return d_layers_to_copy, student, student_cfg, teacher + + def freeze_embeds(self): + freeze_params(self.model.shared) + for d in [self.model.encoder, self.model.decoder]: + 
freeze_params(d.embed_tokens) + + def sanity_check_gradients(self): + """T5""" + assert_all_frozen(self.teacher) + assert_all_frozen(self.model.decoder.embed_tokens) + assert_all_frozen(self.model.encoder.embed_tokens) + if self.different_encoder: + assert any_requires_grad(self.model.encoder) + else: + freeze_params(self.model.encoder) + del self.teacher.model.encoder + if self.different_decoder: + assert any_requires_grad(self.model.decoder) + else: + freeze_params(self.model.decoder) # TODO(SS): very suspicious + + def _step(self, batch): + pad_token_id = self.tokenizer.pad_token_id + source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] + decoder_input_ids = y[:, :-1].contiguous() + labels = y[:, 1:].clone() + labels[y[:, 1:] == pad_token_id] = -100 + # noinspection PyCallingNonCallable + dec_mask = decoder_input_ids.ne(pad_token_id) + + sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self( + source_ids, + attention_mask=source_mask, + decoder_input_ids=decoder_input_ids, + labels=labels, + output_hidden_states=True, + output_attentions=False, + use_cache=False, + ) + + def zero_tensor(): + return torch.tensor(0.0).type_as(sloss) + + loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor() + if self.different_encoder: + with torch.no_grad(): + teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder( + source_ids, attention_mask=source_mask, output_hidden_states=True, use_cache=False, + ) + if self.hparams.alpha_encoder_loss > 0: + loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask) + + hid_loss_enc = self.calc_hidden_loss( + source_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy + ) + + teacher_enc_outputs = (enc_outputs,) + assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs) + + with torch.no_grad(): + tloss, tlogits, tdec_hidden, _ = self.teacher( + source_ids, + attention_mask=source_mask, + 
encoder_outputs=teacher_enc_outputs, + decoder_input_ids=decoder_input_ids, + lm_labels=labels, + output_hidden_states=True, + use_cache=False, + ) + + loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits) + if self.alpha_hid > 0: + hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy) + + blended_loss = ( + self.alpha_ce * loss_ce + + self.alpha_mlm * sloss + + self.hparams.alpha_encoder_loss * loss_encoder + + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec) + ) + return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec + + +def create_module(args): + t5 = "t5" in args.model_name_or_path + if args.no_teacher: + assert not args.enc_only + module_cls = SummarizationModule + elif t5: + module_cls = T5SummarizationDistiller + elif args.enc_only: + raise ValueError("Deleted that") + else: + module_cls = SummarizationDistiller + args.setup_cls: str = module_cls.__name__ + model = module_cls(args) + return model + + +def evaluate_checkpoint(ckpt_path: Path, dest_dir=None): + exp_dir = ckpt_path.parent + if dest_dir is None: + dest_dir = exp_dir + clash = list(dest_dir.glob("test_generations*")) + if clash: + print(f"SKIPPING to avoid overwriting {clash}") + ckpt = torch.load(ckpt_path, map_location="cpu") + if "hparams" in ckpt: + args = argparse.Namespace(**ckpt["hparams"]) + else: + args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl")) + args.resume_from_checkpoint = str(ckpt_path) + args.do_train = False + args.output_dir = str(dest_dir) + args.n_gpu = 1 + args.eval_batch_size = 16 + Path(args.output_dir).mkdir(exist_ok=True) + model = create_module(args) + trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False) + trainer.test(model) + + +def get_layers_to_copy(n_to_get, tot): + all_layers = list(range(tot)) + if tot == 12: # Alternating for special cases + layers_to_copy = { # maps # layers in student -> which teacher layers 
to copy + 6: [0, 2, 4, 7, 9, 11], + 1: [11], + 3: [0, 6, 11], + 2: [0, 11], + 4: [0, 4, 8, 11], + 9: [0, 1, 2, 4, 5, 7, 9, 10, 11], + 12: all_layers, + } + return layers_to_copy[n_to_get] + else: + return all_layers[:n_to_get] + + +def distill_main(args): + Path(args.output_dir).mkdir(exist_ok=True) + if len(os.listdir(args.output_dir)) > 3 and args.do_train: + raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) + + model = create_module(args) + return ft_main(args, model=model) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd()) + args = parser.parse_args() + + distill_main(args) diff --git a/examples/summarization/evaluate_cnn.py b/examples/summarization/evaluate_cnn.py deleted file mode 100644 index f2b57f09cd..0000000000 --- a/examples/summarization/evaluate_cnn.py +++ /dev/null @@ -1,100 +0,0 @@ -import argparse -from pathlib import Path - -import torch -from rouge_score import rouge_scorer, scoring -from tqdm import tqdm - -from transformers import AutoModelWithLMHead, AutoTokenizer - - -DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - - -def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] - - -def generate_summaries( - examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE -): - fout = Path(out_file).open("w", encoding="utf-8") - model = AutoModelWithLMHead.from_pretrained(model_name).to(device) - - tokenizer = AutoTokenizer.from_pretrained(model_name) - - # update config with summarization specific params - task_specific_params = model.config.task_specific_params - if task_specific_params is not None: - model.config.update(task_specific_params.get("summarization", {})) - - for batch in tqdm(list(chunks(examples, batch_size))): - if "t5" in model_name: - batch = [model.config.prefix + 
text for text in batch] - dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to( - device - ) - summaries = model.generate(**dct) - - dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) - for hypothesis in dec: - fout.write(hypothesis + "\n") - fout.flush() - - -def calculate_rouge(output_lns, reference_lns, score_path): - score_file = Path(score_path).open("w") - scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True) - aggregator = scoring.BootstrapAggregator() - - for reference_ln, output_ln in zip(reference_lns, output_lns): - scores = scorer.score(reference_ln, output_ln) - aggregator.add_scores(scores) - - result = aggregator.aggregate() - score_file.write( - "ROUGE_1: \n{} \n\n ROUGE_2: \n{} \n\n ROUGE_L: \n{} \n\n".format( - result["rouge1"], result["rouge2"], result["rougeL"] - ) - ) - - -def run_generate(): - parser = argparse.ArgumentParser() - parser.add_argument( - "input_path", type=str, help="like cnn_dm/test.source or cnn_dm/test_articles_input.txt", - ) - parser.add_argument( - "output_path", type=str, help="where to save summaries", - ) - parser.add_argument( - "model_name", - type=str, - default="facebook/bart-large-cnn", - help="like bart-large-cnn,'t5-small', 't5-base', 't5-large', 't5-3b', 't5-11b", - ) - parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt") - parser.add_argument( - "--score_path", type=str, required=False, help="where to save the rouge score", - ) - parser.add_argument( - "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.", - ) - parser.add_argument( - "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time", - ) - args = parser.parse_args() - examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()] - - 
generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device) - if args.score_path is not None: - output_lns = [x.rstrip() for x in open(args.output_path).readlines()] - reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()] - - calculate_rouge(output_lns, reference_lns, args.score_path) - - -if __name__ == "__main__": - run_generate() diff --git a/examples/summarization/finetune.py b/examples/summarization/finetune.py index 078491f918..56a1984635 100644 --- a/examples/summarization/finetune.py +++ b/examples/summarization/finetune.py @@ -3,91 +3,169 @@ import glob import logging import os import time +from pathlib import Path +from typing import Dict, List, Tuple +import numpy as np +import pytorch_lightning as pl import torch from torch.utils.data import DataLoader -from lightning_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup +from lightning_base import BaseTransformer, add_generic_args, generic_train +from transformers import get_linear_schedule_with_warmup try: - from .utils import SummarizationDataset + from .utils import ( + use_task_specific_params, + SummarizationDataset, + lmap, + flatten_list, + pickle_save, + save_git_info, + freeze_params, + calculate_rouge, + get_git_info, + ROUGE_KEYS, + ) + from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback except ImportError: - from utils import SummarizationDataset - + from utils import ( + use_task_specific_params, + SummarizationDataset, + lmap, + flatten_list, + pickle_save, + save_git_info, + freeze_params, + calculate_rouge, + get_git_info, + ROUGE_KEYS, + ) + from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback logger = logging.getLogger(__name__) -class SummarizationTrainer(BaseTransformer): +class SummarizationModule(BaseTransformer): + mode = "summarization" + loss_names = ["loss"] - mode = "language-modeling" + def __init__(self, hparams, 
**kwargs): + super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs) + use_task_specific_params(self.model, "summarization") + save_git_info(self.hparams.output_dir) + self.metrics_save_path = Path(self.output_dir) / "metrics.pkl" + self.hparams_save_path = Path(self.output_dir) / "hparams.pkl" + self.step_count = 0 + self.metrics = {"train": [], "val": [], "test": []} - def __init__(self, hparams): - super().__init__(hparams, num_labels=None, mode=self.mode) self.dataset_kwargs: dict = dict( data_dir=self.hparams.data_dir, max_source_length=self.hparams.max_source_length, - max_target_length=self.hparams.max_target_length, + prefix=self.model.config.prefix or "", ) + n_observations_per_split = { + "train": self.hparams.n_train, + "val": self.hparams.n_val, + "test": self.hparams.n_test, + } + self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()} - def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None): - return self.model( - input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, lm_labels=lm_labels, + self.target_lens = { + "train": self.hparams.max_target_length, + "val": self.hparams.val_max_target_length, + "test": self.hparams.test_max_target_length, + } + assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}" + assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}" + + if self.hparams.freeze_embeds: + self.freeze_embeds() + if self.hparams.freeze_encoder: + freeze_params(self.model.model.encoder) # TODO: this will break for t5 + self.hparams.git_sha = get_git_info()["repo_sha"] + self.num_workers = 4 if self.hparams.gpus <= 1 else None # passing num_workers breaks lightning for multigpu + + def freeze_embeds(self): + """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" + if self.model.config.model_type == "bart": + 
freeze_params(self.model.model.shared) + for d in [self.model.model.encoder, self.model.model.decoder]: + freeze_params(d.embed_positions) + freeze_params(d.embed_tokens) + else: + freeze_params(self.model.shared) + for d in [self.model.encoder, self.model.decoder]: + freeze_params(d.embed_tokens) + + def forward(self, input_ids, **kwargs): + return self.model(input_ids, **kwargs) + + def ids_to_clean_text(self, generated_ids: List[int]): + gen_text = self.tokenizer.batch_decode( + generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True ) + return lmap(str.strip, gen_text) - def _step(self, batch): + def _step(self, batch: dict) -> Tuple: pad_token_id = self.tokenizer.pad_token_id - source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"] + source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] y_ids = y[:, :-1].contiguous() lm_labels = y[:, 1:].clone() lm_labels[y[:, 1:] == pad_token_id] = -100 - outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,) - + outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,) loss = outputs[0] + return (loss,) - return loss + def training_step(self, batch, batch_idx) -> Dict: + loss_tensors = self._step(batch) + logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} + return {"loss": loss_tensors[0], "log": logs} - def training_step(self, batch, batch_idx): - loss = self._step(batch) + def validation_step(self, batch, batch_idx) -> Dict: + return self._generative_step(batch) - tensorboard_logs = {"train_loss": loss} - return {"loss": loss, "log": tensorboard_logs} + def validation_end(self, outputs, prefix="val") -> Dict: + self.step_count += 1 + losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names} + loss = losses["loss"] + rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS 
+ ["gen_time", "summ_len"]} + rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss) + rouges.update({k: v.item() for k, v in losses.items()}) + losses.update(rouges) + metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} + metrics["step_count"] = self.step_count + self.save_metrics(metrics, prefix) # writes to self.metrics_save_path + preds = flatten_list([x["preds"] for x in outputs]) + return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor} - def validation_step(self, batch, batch_idx): - loss = self._step(batch) - return {"val_loss": loss} + def save_metrics(self, metrics, prefix) -> None: + self.metrics[prefix].append(metrics) + pickle_save(self.metrics, self.metrics_save_path) - def validation_end(self, outputs): - avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() - tensorboard_logs = {"val_loss": avg_loss} - return {"avg_val_loss": avg_loss, "log": tensorboard_logs} - - def test_step(self, batch, batch_idx): + def _generative_step(self, batch): pad_token_id = self.tokenizer.pad_token_id source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id) - # NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py - generated_ids = self.model.generate( - input_ids=source_ids, - attention_mask=source_mask, - num_beams=1, - max_length=80, - repetition_penalty=2.5, - length_penalty=1.0, - early_stopping=True, - use_cache=True, - ) - preds = [ - self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) - for g in generated_ids - ] - target = [self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y] - loss = self._step(batch) + # TODO(SS): task specific params - return {"val_loss": loss, "preds": preds, "target": target} + t0 = time.time() + generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,) + 
gen_time = time.time() - t0 + preds = self.ids_to_clean_text(generated_ids) + target = self.ids_to_clean_text(y) + loss_tensors = self._step(batch) + base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} + rouge: Dict = calculate_rouge(preds, target) + summ_len = np.mean(lmap(len, generated_ids)) + base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge) + return base_metrics + + def test_step(self, batch, batch_idx): + return self._generative_step(batch) def test_end(self, outputs): - return self.validation_end(outputs) + return self.validation_end(outputs, prefix="test") def test_epoch_end(self, outputs): output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt") @@ -102,15 +180,43 @@ class SummarizationTrainer(BaseTransformer): return self.test_end(outputs) + def validation_epoch_end(self, outputs): + self.validation_end(outputs, "val") + + def get_dataset(self, type_path) -> SummarizationDataset: + n_obs = self.n_obs[type_path] + max_target_length = self.target_lens[type_path] + dataset = SummarizationDataset( + self.tokenizer, + type_path=type_path, + n_obs=n_obs, + max_target_length=max_target_length, + **self.dataset_kwargs, + ) + return dataset + def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader: - dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs) - dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle) + dataset = self.get_dataset(type_path) + sampler = None + if self.hparams.sortish_sampler and type_path == "train": + assert self.hparams.gpus <= 1 # TODO: assert earlier + sampler = dataset.make_sortish_sampler(batch_size) + shuffle = False + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=dataset.collate_fn, + shuffle=shuffle, + num_workers=self.num_workers, + sampler=sampler, + ) return dataloader 
def train_dataloader(self) -> DataLoader: dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True) t_total = ( - (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu))) + (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus))) // self.hparams.gradient_accumulation_steps * float(self.hparams.num_train_epochs) ) @@ -129,7 +235,7 @@ class SummarizationTrainer(BaseTransformer): @staticmethod def add_model_specific_args(parser, root_dir): BaseTransformer.add_model_specific_args(parser, root_dir) - # Add BART specific options + add_generic_args(parser, root_dir) parser.add_argument( "--max_source_length", default=1024, @@ -144,41 +250,82 @@ class SummarizationTrainer(BaseTransformer): help="The maximum total input sequence length after tokenization. Sequences longer " "than this will be truncated, sequences shorter will be padded.", ) - + parser.add_argument( + "--val_max_target_length", + default=142, # these defaults are optimized for CNNDM. For xsum, see README.md. + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) + parser.add_argument( + "--test_max_target_length", + default=142, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", + ) parser.add_argument( "--data_dir", - default=None, type=str, required=True, - help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.", + help="The input data dir. 
Should contain train.source, train.target, val.source, val.target, test.source, test.target", ) + parser.add_argument("--freeze_encoder", action="store_true") + parser.add_argument("--freeze_embeds", action="store_true") + parser.add_argument("--sortish_sampler", action="store_true", default=False) + parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default") + parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.") + parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.") + parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.") return parser -def main(args): +def main(args, model=None) -> SummarizationModule: + Path(args.output_dir).mkdir(exist_ok=True) + if len(os.listdir(args.output_dir)) > 3 and args.do_train: + raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) + if model is None: + model: BaseTransformer = SummarizationModule(args) + if ( + args.logger == "default" + or args.fast_dev_run + or str(args.output_dir).startswith("/tmp") + or str(args.output_dir).startswith("/var") + ): + logger = True # don't pollute wandb logs unnecessarily + elif args.logger == "wandb": + from pytorch_lightning.loggers import WandbLogger - # If output_dir not provided, a folder will be generated in pwd - if not args.output_dir: - args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",) - os.makedirs(args.output_dir) - model = SummarizationTrainer(args) - trainer = generic_train(model, args) + logger = WandbLogger(name=model.output_dir.name) + elif args.logger == "wandb_shared": + from pytorch_lightning.loggers import WandbLogger - # Optionally, predict on dev set and write to output_dir - if args.do_predict: - # See https://github.com/huggingface/transformers/issues/3159 - # pl use this format 
to create a checkpoint: - # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\ - # /pytorch_lightning/callbacks/model_checkpoint.py#L169 - checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True))) - model = model.load_from_checkpoint(checkpoints[-1]) - trainer.test(model) + # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB. + logger = WandbLogger(name=model.output_dir.name, project="hf_summarization") + trainer: pl.Trainer = generic_train( + model, + args, + logging_callback=Seq2SeqLoggingCallback(), + checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir), + logger=logger, + # TODO: early stopping callback seems messed up + ) + if not args.do_predict: + return model + + model.hparams.test_checkpoint = "" + checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))) + if checkpoints: + model.hparams.test_checkpoint = checkpoints[-1] + trainer.resume_from_checkpoint = checkpoints[-1] + trainer.logger.log_hyperparams(model.hparams) + trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics. 
+ return model if __name__ == "__main__": parser = argparse.ArgumentParser() - add_generic_args(parser, os.getcwd()) - parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd()) + parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) args = parser.parse_args() main(args) diff --git a/examples/summarization/finetune.sh b/examples/summarization/finetune.sh new file mode 100755 index 0000000000..ead55c5892 --- /dev/null +++ b/examples/summarization/finetune.sh @@ -0,0 +1,23 @@ +export OUTPUT_DIR=bart_cnn_finetune + +# Make output directory if it doesn't exist +mkdir -p $OUTPUT_DIR + +# Add parent directory to python path to access lightning_base.py +export PYTHONPATH="../":"${PYTHONPATH}" + + +# --model_name_or_path=t5-base for t5 + +python finetune.py \ + --model_name_or_path=facebook/bart-large \ + --learning_rate=3e-5 \ + --fp16 \ + --gpus 1 \ + --do_train \ + --do_predict \ + --n_val 1000 \ + --val_check_interval 0.1 \ + --sortish_sampler \ + --max_target_length=56 \ + $@ diff --git a/examples/summarization/finetune_bart.sh b/examples/summarization/finetune_bart.sh deleted file mode 100644 index b37888f5f4..0000000000 --- a/examples/summarization/finetune_bart.sh +++ /dev/null @@ -1,18 +0,0 @@ -export OUTPUT_DIR_NAME=bart_sum -export CURRENT_DIR=${PWD} -export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME} - -# Make output directory if it doesn't exist -mkdir -p $OUTPUT_DIR - -# Add parent directory to python path to access lightning_base.py -export PYTHONPATH="../":"${PYTHONPATH}" - -python finetune.py \ ---data_dir=./cnn-dailymail/cnn_dm \ ---model_name_or_path=bart-large \ ---learning_rate=3e-5 \ ---train_batch_size=4 \ ---eval_batch_size=4 \ ---output_dir=$OUTPUT_DIR \ ---do_train $@ diff --git a/examples/summarization/initialization_utils.py b/examples/summarization/initialization_utils.py new file mode 100644 index 0000000000..02cba8b352 --- /dev/null +++ b/examples/summarization/initialization_utils.py @@ -0,0 +1,20 @@ 
+from typing import List + +from torch import nn + + +def init_student(student, teacher): + teacher_state_dict = teacher.state_dict() + info = student.load_state_dict(teacher_state_dict, strict=False) + assert info.missing_keys == [], info.missing_keys + return student, info + + +def copy_decoder_layers(teacher, student, l2copy=[0, 2, 4, 7, 9, 11]): + copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, l2copy) + + +def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, layers_to_copy: List) -> None: + layers_to_copy = nn.ModuleList([l for i, l in enumerate(teacher_layers) if i in layers_to_copy]) + assert len(student_layers) == len(layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}" + student_layers.load_state_dict(layers_to_copy.state_dict()) diff --git a/examples/summarization/run_distiller.sh b/examples/summarization/run_distiller.sh new file mode 100755 index 0000000000..6fbecad388 --- /dev/null +++ b/examples/summarization/run_distiller.sh @@ -0,0 +1,12 @@ +#CNN_DIR = /home/shleifer/transformers_fork/examples/summarization/bart/cnn_dm + +# Add parent directory to python path to access lightning_base.py +export PYTHONPATH="../":"${PYTHONPATH}" + +python distillation.py \ +--learning_rate=3e-4 \ +--do_train \ +--do_predict \ +--fp16 \ +--val_check_interval 0.1 \ +$@ diff --git a/examples/summarization/run_eval.py b/examples/summarization/run_eval.py new file mode 100644 index 0000000000..3013743c89 --- /dev/null +++ b/examples/summarization/run_eval.py @@ -0,0 +1,78 @@ +import argparse +import json +from pathlib import Path + +import torch +from tqdm import tqdm + +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + + +try: + from .finetune import calculate_rouge, use_task_specific_params +except ImportError: + from finetune import calculate_rouge, use_task_specific_params + +DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +def chunks(lst, n): + """Yield successive n-sized 
chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def generate_summaries( + examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False, +) -> None: + fout = Path(out_file).open("w", encoding="utf-8") + model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) + if fp16: + model = model.half() + + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # update config with summarization specific params + use_task_specific_params(model, "summarization") + + for batch in tqdm(list(chunks(examples, batch_size))): + if "t5" in model_name: + batch = [model.config.prefix + text for text in batch] + dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to( + device + ) + summaries = model.generate(**dct) + + dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) + for hypothesis in dec: + fout.write(hypothesis + "\n") + fout.flush() + + +def run_generate(): + parser = argparse.ArgumentParser() + parser.add_argument("input_path", type=str, help="like cnn_dm/test.source") + parser.add_argument("output_path", type=str, help="where to save summaries") + parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.") + parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt") + parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format") + parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.") + parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") + parser.add_argument("--fp16", action="store_true") + args = parser.parse_args() + examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()] + + 
generate_summaries( + examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16 + ) + if args.score_path is not None: + output_lns = [x.rstrip() for x in open(args.output_path).readlines()] + reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()] + + rouge: dict = calculate_rouge(output_lns, reference_lns) + + json.dump(rouge, open("score_path", "w+")) + + +if __name__ == "__main__": + run_generate() diff --git a/examples/summarization/test_summarization_examples.py b/examples/summarization/test_summarization_examples.py index fd0ff38be3..9688d7f88a 100644 --- a/examples/summarization/test_summarization_examples.py +++ b/examples/summarization/test_summarization_examples.py @@ -7,28 +7,40 @@ import unittest from pathlib import Path from unittest.mock import patch +import torch from torch.utils.data import DataLoader from transformers import BartTokenizer -from .evaluate_cnn import run_generate +from .distillation import distill_main, evaluate_checkpoint from .finetune import main -from .utils import SummarizationDataset +from .run_eval import generate_summaries, run_generate +from .utils import SummarizationDataset, lmap, pickle_load logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger() - -DEFAULT_ARGS = { +FP16_EVER = False +CHEAP_ARGS = { + "logger": "default", + "alpha_hid": 0, + "freeze_embeds": True, + "enc_only": False, + "tgt_suffix": "", + "resume_from_checkpoint": None, + "sortish_sampler": True, + "student_decoder_layers": 1, + "val_check_interval": 1.0, "output_dir": "", "fp16": False, + "no_teacher": False, "fp16_opt_level": "O1", - "n_gpu": 1, + "gpus": 1 if torch.cuda.is_available() else 0, "n_tpu_cores": 0, "max_grad_norm": 1.0, "do_train": True, - "do_predict": False, + "do_predict": True, "gradient_accumulation_steps": 1, "server_ip": "", "server_port": "", @@ -36,7 +48,7 @@ DEFAULT_ARGS = { "model_type": "bart", "model_name_or_path": "sshleifer/bart-tiny-random", 
"config_name": "", - "tokenizer_name": "", + "tokenizer_name": "facebook/bart-large", "cache_dir": "", "do_lower_case": False, "learning_rate": 3e-05, @@ -48,6 +60,17 @@ DEFAULT_ARGS = { "eval_batch_size": 2, "max_source_length": 12, "max_target_length": 12, + "val_max_target_length": 12, + "test_max_target_length": 12, + "fast_dev_run": False, + "no_cache": False, + "n_train": -1, + "n_val": -1, + "n_test": -1, + "student_encoder_layers": 1, + "alpha_loss_encoder": 0.0, + "freeze_encoder": False, + "auto_scale_batch_size": False, } @@ -56,6 +79,9 @@ def _dump_articles(path: Path, articles: list): f.write("\n".join(articles)) +BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute() + + def make_test_data_dir(): tmp_dir = Path(tempfile.gettempdir()) articles = [" Sam ate lunch today", "Sams lunch ingredients"] @@ -66,6 +92,169 @@ def make_test_data_dir(): return tmp_dir +@unittest.skip("These wont' pass until hidden_states kwarg is merged.") +class TestSummarizationDistiller(unittest.TestCase): + @classmethod + def setUpClass(cls): + logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks + return cls + + @unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test") + def test_bdc_multigpu(self): + updates = dict( + student_encoder_layers=2, + student_decoder_layers=1, + no_teacher=True, + freeze_encoder=True, + gpus=2, + sortish_sampler=False, + ) + self._bart_distiller_cli(updates) + + @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test") + def test_bdc_fp16(self): + updates = dict( + student_encoder_layers=2, + student_decoder_layers=1, + alpha_hid=3.0, + freeze_encoder=True, + gpus=1, + fp16=FP16_EVER, + fp16_opt_level="O1", + ) + self._bart_distiller_cli(updates) + + @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test") + def test_bdc_t5_eval_fp16(self): + updates = dict( + fp16=FP16_EVER, + gpus=1, + model_type="t5", + 
model_name_or_path="patrickvonplaten/t5-tiny-random", + do_train=False, + do_predict=True, + tokenizer_name=None, + no_teacher=True, + ) + self._bart_distiller_cli(updates, check_contents=False) + + @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test") + def test_bdc_t5_train_fp16(self): + updates = dict( + fp16=FP16_EVER, + gpus=1, + model_type="t5", + model_name_or_path="patrickvonplaten/t5-tiny-random", + do_train=True, + do_predict=True, + tokenizer_name="patrickvonplaten/t5-tiny-random", + no_teacher=True, + ) + self._bart_distiller_cli(updates) + + def test_bdc_no_teacher(self): + updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True,) + self._bart_distiller_cli(updates) + + def test_bdc_yes_teacher(self): + updates = dict(student_encoder_layers=2, student_decoder_layers=1,) + self._bart_distiller_cli(updates) + + def test_bdc_checkpointing(self): + + updates = dict( + student_encoder_layers=2, + student_decoder_layers=1, + num_train_epochs=4, + val_check_interval=0.25, + alpha_hid=2.0, + ) + model = self._bart_distiller_cli(updates, check_contents=False) + + ckpts = list(Path(model.output_dir).glob("*.ckpt")) + self.assertEqual(1, len(ckpts)) + transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin")) + self.assertEqual(len(transformer_ckpts), len(ckpts)) + new_transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin")) + self.assertEqual(len(new_transformer_ckpts), 1) + examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines()) + out_path = tempfile.mktemp() + generate_summaries(examples, out_path, new_transformer_ckpts[0].parent) + self.assertTrue(Path(out_path).exists()) + + evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp())) + + def test_bdc_t5(self): + updates = dict( + student_encoder_layers=1, + student_decoder_layers=1, + alpha_hid=2.0, + teacher="patrickvonplaten/t5-tiny-random", + model_type="t5", + 
model_name_or_path="patrickvonplaten/t5-tiny-random", + tokenizer_name="patrickvonplaten/t5-tiny-random", + ) + self._bart_distiller_cli(updates) + + def test_bdc_t5_eval(self): + updates = dict( + model_type="t5", + model_name_or_path="patrickvonplaten/t5-tiny-random", + do_train=False, + do_predict=True, + tokenizer_name="patrickvonplaten/t5-tiny-random", + no_teacher=True, + ) + self._bart_distiller_cli(updates, check_contents=False) + + def _bart_distiller_cli(self, updates, check_contents=True): + default_updates = dict( + model_type="bart", + train_batch_size=1, + eval_batch_size=2, + num_train_epochs=2, + alpha_mlm=0.2, + alpha_ce=0.8, + do_predict=True, + gpus=1 if torch.cuda.is_available() else 0, + model_name_or_path="sshleifer/tinier_bart", + teacher=CHEAP_ARGS["model_name_or_path"], + val_check_interval=0.5, + alpha_encoder_loss=0.4, + ) + default_updates.update(updates) + args_d: dict = CHEAP_ARGS.copy() + tmp_dir = make_test_data_dir() + output_dir = tempfile.mkdtemp(prefix="output_") + + args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates) + model = distill_main(argparse.Namespace(**args_d)) + if not check_contents: + return model + contents = os.listdir(output_dir) + ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt" + contents = {os.path.basename(p) for p in contents} + self.assertIn(ckpt_name, contents) + self.assertIn("metrics.pkl", contents) + self.assertIn("test_generations.txt", contents) + self.assertIn("val_generations_1.txt", contents) + self.assertIn("val_1_results.txt", contents) + self.assertIn("test_results.txt", contents) + # self.assertEqual(len(contents), 15) + + metrics = pickle_load(Path(output_dir) / "metrics.pkl") + import pandas as pd + + val_df = pd.DataFrame(metrics["val"]) + train_df = pd.DataFrame(metrics["train"]) + test_df = pd.DataFrame(metrics["test"]) + desired_n_evals = args_d["num_train_epochs"] * 2 + 1 + 
self.assertEqual(val_df.shape[0], desired_n_evals) # + self.assertEqual(test_df.shape[1], val_df.shape[1]) + self.assertEqual(train_df.shape[0], 0) + return model + + class TestBartExamples(unittest.TestCase): @classmethod def setUpClass(cls): @@ -79,49 +268,31 @@ class TestBartExamples(unittest.TestCase): output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo" articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] _dump_articles(tmp, articles) - testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"] + testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"] with patch.object(sys, "argv", testargs): run_generate() self.assertTrue(Path(output_file_name).exists()) os.remove(Path(output_file_name)) - def test_bart_run_sum_cli(self): - args_d: dict = DEFAULT_ARGS.copy() - tmp_dir = make_test_data_dir() - output_dir = tempfile.mkdtemp(prefix="output_") - args_d.update( - data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir, - ) - main(argparse.Namespace(**args_d)) - args_d.update({"do_train": False, "do_predict": True}) - - main(argparse.Namespace(**args_d)) - contents = os.listdir(output_dir) - expected_contents = { - "checkpointepoch=0.ckpt", - "test_results.txt", - } - created_files = {os.path.basename(p) for p in contents} - self.assertSetEqual(expected_contents, created_files) - def test_t5_run_sum_cli(self): - args_d: dict = DEFAULT_ARGS.copy() + args_d: dict = CHEAP_ARGS.copy() + tmp_dir = make_test_data_dir() output_dir = tempfile.mkdtemp(prefix="output_") args_d.update( data_dir=tmp_dir, model_type="t5", model_name_or_path="patrickvonplaten/t5-tiny-random", + tokenizer_name=None, # "patrickvonplaten/t5-tiny-random", train_batch_size=2, eval_batch_size=2, - n_gpu=0, + gpus=0, output_dir=output_dir, do_predict=True, ) - main(argparse.Namespace(**args_d)) - - # 
args_d.update({"do_train": False, "do_predict": True}) - # main(argparse.Namespace(**args_d)) + assert "n_train" in args_d + args = argparse.Namespace(**args_d) + main(args) def test_bart_summarization_dataset(self): tmp_dir = Path(tempfile.gettempdir()) @@ -138,42 +309,16 @@ class TestBartExamples(unittest.TestCase): ) dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) for batch in dataloader: - self.assertEqual(batch["source_mask"].shape, batch["source_ids"].shape) + self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape) # show that articles were trimmed. - self.assertEqual(batch["source_ids"].shape[1], max_len_source) - self.assertGreater(20, batch["source_ids"].shape[1]) # trimmed significantly + self.assertEqual(batch["input_ids"].shape[1], max_len_source) + self.assertGreater(20, batch["input_ids"].shape[1]) # trimmed significantly # show that targets were truncated - self.assertEqual(batch["target_ids"].shape[1], trunc_target) # Truncated + self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target) # Truncated self.assertGreater(max_len_target, trunc_target) # Truncated -class TestT5Examples(unittest.TestCase): - def test_t5_cli(self): - output_file_name = "output_t5_sum.txt" - score_file_name = "score_t5_sum.txt" - articles = ["New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] - stream_handler = logging.StreamHandler(sys.stdout) - logger.addHandler(stream_handler) - tmp = Path(tempfile.gettempdir()) / "utest_generations_t5_sum.hypo" - with tmp.open("w", encoding="utf-8") as f: - f.write("\n".join(articles)) - - output_file_name = Path(tempfile.gettempdir()) / "utest_output_t5_sum.hypo" - score_file_name = Path(tempfile.gettempdir()) / "utest_score_t5_sum.hypo" - - testargs = [ - "evaluate_cnn.py", - str(tmp), - str(output_file_name), - "patrickvonplaten/t5-tiny-random", - "--reference_path", - str(tmp), - "--score_path", - str(score_file_name), - 
] - - with patch.object(sys, "argv", testargs): - run_generate() - self.assertTrue(Path(output_file_name).exists()) - self.assertTrue(Path(score_file_name).exists()) +def list_to_text_file(lst, path): + dest = Path(path) + dest.open("w+").writelines(lst) diff --git a/examples/summarization/utils.py b/examples/summarization/utils.py index 874ec2b4a5..a375d823ed 100644 --- a/examples/summarization/utils.py +++ b/examples/summarization/utils.py @@ -1,20 +1,66 @@ +import itertools +import json import os +import pickle +from pathlib import Path +from typing import Dict, Iterable, List +import git +import numpy as np import torch -from torch.utils.data import Dataset +from rouge_score import rouge_scorer, scoring +from torch import nn +from torch.utils.data import Dataset, Sampler +from tqdm import tqdm + +from transformers import BartTokenizer -def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"): +def encode_file( + tokenizer, + data_path, + max_length, + pad_to_max_length=True, + return_tensors="pt", + overwrite_cache=False, + prefix="", + tok_name="", +): + cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt") + if not overwrite_cache and cache_path.exists(): + try: + examples = torch.load(cache_path) + assert isinstance(examples, list) + return examples + + except Exception: + print(f"failed to load from {cache_path}, retokenizing {data_path}") + data_path = Path(data_path) + + lns = lmap(str.strip, data_path.open().readlines()) + lns = [prefix + text for text in lns] + assert lns, f"found empty file at {data_path}" examples = [] - with open(data_path, "r") as f: - for text in f.readlines(): - tokenized = tokenizer.batch_encode_plus( - [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, - ) - examples.append(tokenized) + for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"): + tokenized = tokenizer.batch_encode_plus( + [text], # DONT ADD SPACES + 
max_length=max_length, + pad_to_max_length=pad_to_max_length, + add_prefix_space=True, + return_tensors=return_tensors, + ) + examples.append(tokenized) + torch.save(lmap(dict, examples), cache_path.open("wb")) return examples +def lmap(f, x): + return list(map(f, x)) + + +T5_PREFIX = "summarize: " # HACK, fixme + + def trim_batch( input_ids, pad_token_id, attention_mask=None, ): @@ -30,15 +76,38 @@ class SummarizationDataset(Dataset): def __init__( self, tokenizer, - data_dir="./cnn-dailymail/cnn_dm/", + data_dir, type_path="train", max_source_length=1024, max_target_length=56, + n_obs=None, + overwrite_cache=False, + prefix="", ): super().__init__() - self.tokenizer = tokenizer + tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else "" + self.source = encode_file( + tokenizer, + os.path.join(data_dir, type_path + ".source"), + max_source_length, + overwrite_cache=overwrite_cache, + prefix=prefix, + tok_name=tok_name, + ) + if type_path == "train": + tgt_path = os.path.join(data_dir, type_path + ".target") + else: + tgt_path = os.path.join(data_dir, type_path + ".target") + + self.target = encode_file( + tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name + ) self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length) self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length) + if n_obs is not None: + self.source = self.source[:n_obs] + self.target = self.target[:n_obs] + self.pad_token_id = tokenizer.pad_token_id def __len__(self): return len(self.source) @@ -47,19 +116,141 @@ class SummarizationDataset(Dataset): source_ids = self.source[index]["input_ids"].squeeze() target_ids = self.target[index]["input_ids"].squeeze() src_mask = self.source[index]["attention_mask"].squeeze() - return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids} + return {"input_ids": source_ids, "attention_mask": src_mask, 
"decoder_input_ids": target_ids} @staticmethod def trim_seq2seq_batch(batch, pad_token_id): - y = trim_batch(batch["target_ids"], pad_token_id) - source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]) + y = trim_batch(batch["decoder_input_ids"], pad_token_id) + source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"]) return source_ids, source_mask, y - def collate_fn(self, batch): - input_ids = torch.stack([x["source_ids"] for x in batch]) - masks = torch.stack([x["source_mask"] for x in batch]) - target_ids = torch.stack([x["target_ids"] for x in batch]) - pad_token_id = self.tokenizer.pad_token_id + def collate_fn(self, batch) -> dict: + input_ids = torch.stack([x["input_ids"] for x in batch]) + masks = torch.stack([x["attention_mask"] for x in batch]) + target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) + pad_token_id = self.pad_token_id y = trim_batch(target_ids, pad_token_id) source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) - return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y} + batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y} + return batch + + @property + def src_lens(self): # Can delete? + return lmap(len, self.source) + + @property + def tgt_lens(self): + return lmap(len, self.target) + + def make_sortish_sampler(self, batch_size): + return SortishSampler(self.source, batch_size) + + +class SortishSampler(Sampler): + "Go through the text data by order of src length with a bit of randomness. From fastai repo." 
+ + def __init__(self, data, batch_size): + self.data, self.bs = data, batch_size + + def key(self, i): + return len(self.data[i]) + + def __len__(self) -> int: + return len(self.data) + + def __iter__(self): + idxs = np.random.permutation(len(self.data)) + sz = self.bs * 50 + ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] + sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx]) + sz = self.bs + ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] + max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, + ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. + sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) + sort_idx = np.concatenate((ck_idx[0], sort_idx)) + return iter(sort_idx) + + +def use_task_specific_params(model, task): + # update config with summarization specific params + task_specific_params = model.config.task_specific_params + if task_specific_params is not None: + model.config.update(task_specific_params.get(task, {})) + + +def pickle_load(path): + """pickle.load(path)""" + with open(path, "rb") as f: + return pickle.load(f) + + +def pickle_save(obj, path): + """pickle.dump(obj, path)""" + with open(path, "wb") as f: + return pickle.dump(obj, f) + + +def flatten_list(summary_ids: List[List]): + return [x for x in itertools.chain.from_iterable(summary_ids)] + + +def save_git_info(folder_path: str): + """ + Log commit info. 
+ """ + repo_infos = get_git_info() + + with open(os.path.join(folder_path, "git_log.json"), "w") as f: + json.dump(repo_infos, f, indent=4) + + +def get_git_info(): + repo = git.Repo(search_parent_directories=True) + repo_infos = { + "repo_id": str(repo), + "repo_sha": str(repo.head.object.hexsha), + "repo_branch": str(repo.active_branch), + } + return repo_infos + + +ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] + + +def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict: + scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True) + aggregator = scoring.BootstrapAggregator() + + for reference_ln, output_ln in zip(reference_lns, output_lns): + scores = scorer.score(reference_ln, output_ln) + aggregator.add_scores(scores) + + result = aggregator.aggregate() + return {k: v.mid.fmeasure for k, v in result.items()} + + +def freeze_params(model: nn.Module): + for par in model.parameters(): + par.requires_grad = False + + +def grad_status(model: nn.Module) -> Iterable: + return (par.requires_grad for par in model.parameters()) + + +def any_requires_grad(model: nn.Module) -> bool: + return any(grad_status(model)) + + +def assert_all_frozen(model): + model_grads: List[bool] = list(grad_status(model)) + n_require_grad = sum(lmap(int, model_grads)) + npars = len(model_grads) + assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad" + + +def assert_not_all_frozen(model): + model_grads: List[bool] = list(grad_status(model)) + npars = len(model_grads) + assert any(model_grads), f"none of {npars} weights require grad" diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 9eab989ab2..244c8e0b8f 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -59,7 +59,7 @@ BART_GENERATION_EXAMPLE = r""" Examples:: from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig - # see ``examples/summarization/bart/evaluate_cnn.py`` for a 
longer example + # see ``examples/summarization/run_eval.py`` for a longer example model = BartForConditionalGeneration.from_pretrained('bart-large-cnn') tokenizer = BartTokenizer.from_pretrained('bart-large-cnn') ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."