class BartSystem(BaseTransformer):
    """Lightning module that fine-tunes a seq2seq model for summarization.

    The model, tokenizer and optimizer plumbing come from ``BaseTransformer``;
    this class adds the teacher-forced training step, generation-based test
    step and the train/val/test dataloaders built on ``SummarizationDataset``.
    """

    mode = "language-modeling"

    def __init__(self, hparams):
        """Initialise the base module with no classification head size.

        Args:
            hparams: parsed command-line namespace; forwarded unchanged to
                ``BaseTransformer``.
        """
        super().__init__(hparams, num_labels=None, mode=self.mode)

    def forward(
        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
    ):
        """Forward all arguments to the wrapped model and return its output."""
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            lm_labels=lm_labels,
        )

    def _step(self, batch):
        """Compute the loss for one batch.

        The target sequence is shifted by one position: the decoder is fed
        every token but the last, and is trained to predict every token but
        the first.

        Args:
            batch: mapping with ``"source_ids"``, ``"source_mask"`` and
                ``"target_ids"`` entries.

        Returns:
            The first element of the model output (used as the loss).
        """
        target_ids = batch["target_ids"]
        decoder_input_ids = target_ids[:, :-1].contiguous()
        lm_labels = target_ids[:, 1:].clone()
        # Padding positions are replaced by -100 — presumably the ignore index
        # of the model's loss, so padding does not contribute to it.
        lm_labels[target_ids[:, 1:] == self.tokenizer.pad_token_id] = -100
        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_input_ids=decoder_input_ids,
            lm_labels=lm_labels,
        )
        return outputs[0]

    def training_step(self, batch, batch_idx):
        """Return the training loss together with a ``train_loss`` log entry."""
        loss = self._step(batch)

        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Return the loss of one validation batch under ``"val_loss"``."""
        loss = self._step(batch)
        return {"val_loss": loss}

    def validation_end(self, outputs):
        """Average the per-batch ``"val_loss"`` values of ``outputs``."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}
        return {"avg_val_loss": avg_loss, "log": tensorboard_logs}

    def test_step(self, batch, batch_idx):
        """Generate summaries for one batch and decode them next to the targets.

        Returns:
            dict with the batch loss (``"val_loss"``), the decoded generated
            summaries (``"preds"``) and the decoded references (``"target"``).
        """
        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            num_beams=1,
            max_length=80,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
        )
        preds = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]
        target = [
            self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for t in batch["target_ids"]
        ]
        loss = self._step(batch)

        # The loss is stored under "val_loss" so test_end can reuse validation_end.
        return {"val_loss": loss, "preds": preds, "target": target}

    def test_end(self, outputs):
        """Aggregate test outputs exactly like validation outputs."""
        return self.validation_end(outputs)

    def test_epoch_end(self, outputs):
        """Write predictions and targets to ``output_dir`` and aggregate the loss.

        Two files are produced, one summary per line, for later ROUGE
        evaluation: ``test_predictions.txt`` and ``test_targets.txt``.
        """
        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
        # Explicit UTF-8: decoded summaries may contain non-ASCII characters and
        # the platform default encoding is not guaranteed to represent them.
        # The context manager closes both files, so no explicit close() is needed.
        with open(output_test_predictions_file, "w", encoding="utf-8") as p_writer, open(
            output_test_targets_file, "w", encoding="utf-8"
        ) as t_writer:
            for output_batch in outputs:
                p_writer.writelines(s + "\n" for s in output_batch["preds"])
                t_writer.writelines(s + "\n" for s in output_batch["target"])

        return self.test_end(outputs)

    def train_dataloader(self):
        """Build the training dataloader and attach the LR scheduler.

        Side effect: ``self.lr_scheduler`` is set to a linear warmup schedule
        sized from the dataset length, batch size, GPU count, gradient
        accumulation steps and number of epochs.
        """
        train_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="train", block_size=self.hparams.max_seq_length
        )
        # DataLoader does not shuffle by default; training examples should be
        # visited in a fresh random order every epoch.
        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size, shuffle=True)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        """Build the dataloader over the ``val`` split."""
        val_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="val", block_size=self.hparams.max_seq_length
        )
        return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self):
        """Build the dataloader over the ``test`` split."""
        test_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="test", block_size=self.hparams.max_seq_length
        )
        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register the base-model options plus the summarization options.

        Args:
            parser: ``argparse.ArgumentParser`` to extend in place.
            root_dir: forwarded to ``BaseTransformer.add_model_specific_args``.

        Returns:
            The same ``parser`` instance.
        """
        BaseTransformer.add_model_specific_args(parser, root_dir)
        # Add BART specific options
        parser.add_argument(
            "--max_seq_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )

        parser.add_argument(
            "--data_dir",
            default=None,
            type=str,
            required=True,
            help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.",
        )
        return parser
class SummarizationDataset(Dataset):
    """Line-aligned source/target summarization dataset, tokenized eagerly.

    ``<type_path>.source`` and ``<type_path>.target`` are read from
    ``data_dir``; each line is one example. Every line is tokenized and padded
    at construction time and the encodings are kept in memory.
    """

    def __init__(
        self,
        tokenizer,
        data_dir="./cnn-dailymail/cnn_dm/",
        type_path="train",
        block_size=1024,
        max_target_length=56,
    ):
        """Read and tokenize both files of one split.

        Args:
            tokenizer: object exposing ``batch_encode_plus``.
            data_dir: directory that holds the ``.source`` / ``.target`` files.
            type_path: split name used as the file stem (e.g. ``"train"``).
            block_size: ``max_length`` passed to the tokenizer for sources.
            max_target_length: ``max_length`` passed to the tokenizer for
                targets; defaults to the previously hard-coded 56.
        """
        # Zero-argument super() actually runs Dataset.__init__; the former
        # one-argument form only built an unbound super object.
        super().__init__()
        self.tokenizer = tokenizer

        print("loading " + type_path + " source.")
        self.source = self._encode_file(os.path.join(data_dir, type_path + ".source"), block_size)

        print("loading " + type_path + " target.")
        self.target = self._encode_file(os.path.join(data_dir, type_path + ".target"), max_target_length)

    def _encode_file(self, path, max_length):
        """Tokenize every line of ``path`` into one padded encoding per line."""
        # Explicit UTF-8 so loading does not depend on the platform default
        # encoding; iterating the handle avoids materialising all lines first.
        with open(path, "r", encoding="utf-8") as f:
            return [
                self.tokenizer.batch_encode_plus(
                    [text], max_length=max_length, pad_to_max_length=True, return_tensors="pt"
                )
                for text in f
            ]

    def __len__(self):
        """Return the number of source examples."""
        return len(self.source)

    def __getitem__(self, index):
        """Return the ids and mask of example ``index`` without the batch axis.

        Only the leading axis (added by encoding a one-element batch) is
        removed, so a sequence of length 1 stays 1-D rather than collapsing
        to a scalar.
        """
        source_ids = self.source[index]["input_ids"].squeeze(0)
        target_ids = self.target[index]["input_ids"].squeeze(0)
        src_mask = self.source[index]["attention_mask"].squeeze(0)

        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}