examples/seq2seq supports translation (#5202)

This commit is contained in:
Sam Shleifer
2020-06-24 23:58:11 -04:00
committed by GitHub
parent d12ceb48ba
commit 40457bcebb
32 changed files with 626 additions and 636 deletions

169
examples/seq2seq/README.md Normal file
View File

@@ -0,0 +1,169 @@
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
Summarization support is more mature than translation support.
Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
For `bertabs` instructions, see `bertabs/README.md`.
### Data
CNN/DailyMail data
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```
this should make a directory called cnn_dm/ with files like `test.source`.
To use your own data, copy that files format. Each article to be summarized is on its own line.
XSUM Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
WMT16 English-Romanian Translation Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export ENRO_DIR=${PWD}/wmt_en_ro
```
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
The `.source` files are the input, the `.target` files are the desired output.
### Evaluation
To create summaries for each article in dataset, run:
```bash
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
```
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
### Summarization Finetuning
Run/modify `finetune.sh`
The following command should work on a 16GB GPU:
```bash
./finetune.sh \
--data_dir $XSUM_DIR \
--train_batch_size=1 \
--eval_batch_size=1 \
--output_dir=xsum_results \
--num_train_epochs 1 \
--model_name_or_path facebook/bart-large
```
*Note*: The following tips mostly apply to summarization finetuning.
Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default works best).
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
#### Finetuning Outputs
As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine).
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
```bash
output_dir
├── best_tfmr # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
│   ├── config.json
│   ├── merges.txt
│   ├── pytorch_model.bin
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── git_log.json # repo, branch, and commit hash
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score.
├── metrics.json # new validation metrics will continually be appended to this
├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
│   ├── config.json
│   └── pytorch_model.bin
├── test_generations.txt
# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
├── test_results.txt # a convenience file with the test set metrics. This data is also in metrics.json['test']
├── hparams.pkl # the command line args passed after some light preprocessing. Should be saved fairly quickly.
```
After training, you can recover the best checkpoint by running
```python
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
```bash
./finetune.sh \
--data_dir $XSUM_DIR \
--output_dir xsum_frozen_embs \
--model_name_or_path facebook/bart-large \
--logger wandb_shared \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
```
Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
### Distilbart
#### No Teacher Distillation
To run the simpler distilbart-cnn style distillation all you need is data, a GPU, and a properly initialized student.
You don't even need `distillation.py`.
Some [un-finetuned students](https://huggingface.co/models?search=sshleifer%2Fstudent) are available for replication purposes.
They are initialized by copying layers from the associated `bart-large-{cnn|xsum}` teacher using `--init_strategy alternate`. (You can read about that in `initialization_utils.py`)
The command that produced `sshleifer/distilbart-cnn-12-6` is
```bash
./train_distilbart_cnn.sh
```
runtime: 6H on NVIDIA RTX 24GB GPU
*Note*: You can get the same simple distillation logic by using `./run_distiller.sh --no_teacher` followed by identical arguments as the ones in `train_distilbart_cnn.sh`.
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
because you will have the same hyperparameters logged in every run.
#### With a teacher
*Note* only BART variants are supported
In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
This is how `sshleifer/distilbart-xsum*` checkpoints were produced.
The command that produced `sshleifer/distilbart-xsum-12-6` is:
```bash
./train_distilbart_xsum.sh
```
runtime: 13H on V-100 16GB GPU.
### Contributing
- follow the standard contributing guidelines and code of conduct.
- add tests to `test_seq2seq_examples.py`
- To run only the seq2seq tests, you must be in the root of the repository and run:
```bash
pytest examples/seq2seq/
```

View File

View File

@@ -0,0 +1,61 @@
# Text Summarization with Pretrained Encoders
This folder contains part of the code necessary to reproduce the results on abstractive summarization from the article [Text Summarization with Pretrained Encoders](https://arxiv.org/pdf/1908.08345.pdf) by [Yang Liu](https://nlp-yang.github.io/) and [Mirella Lapata](https://homepages.inf.ed.ac.uk/mlap/). It can also be used to summarize any document.
The original code can be found on the Yang Liu's [github repository](https://github.com/nlpyang/PreSumm).
The model is loaded with the pre-trained weights for the abstractive summarization model trained on the CNN/Daily Mail dataset with an extractive and then abstractive tasks.
## Setup
```
git clone https://github.com/huggingface/transformers && cd transformers
pip install .
pip install nltk py-rouge
cd examples/seq2seq/bertabs
```
## Reproduce the authors' ROUGE score
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
```bash
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
```
And move all the stories to the same folder. We will refer as `$DATA_PATH` the path to where you uncompressed both archive. Then run the following in the same folder as `run_summarization.py`:
```bash
python run_summarization.py \
--documents_dir $DATA_PATH \
--summaries_output_dir $SUMMARIES_PATH \ # optional
--no_cuda false \
--batch_size 4 \
--min_length 50 \
--max_length 200 \
--beam_size 5 \
--alpha 0.95 \
--block_trigram true \
--compute_rouge true
```
The scripts executes on GPU if one is available and if `no_cuda` is not set to `true`. Inference on multiple GPUs is not suported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes 30 hours to compute with a single Tesla V100 GPU and a batch size of 10 (300,000 texts to summarize).
## Summarize any text
Put the documents that you would like to summarize in a folder (the path to which is referred to as `$DATA_PATH` below) and run the following in the same folder as `run_summarization.py`:
```bash
python run_summarization.py \
--documents_dir $DATA_PATH \
--summaries_output_dir $SUMMARIES_PATH \ # optional
--no_cuda false \
--batch_size 4 \
--min_length 50 \
--max_length 200 \
--beam_size 5 \
--alpha 0.95 \
--block_trigram true \
```
You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py` and tell it where to fetch the reference summaries.

View File

View File

@@ -0,0 +1,97 @@
# coding=utf-8
# Copyright 2019 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BertAbs configuration """
import logging
from transformers import PretrainedConfig
logger = logging.getLogger(__name__)
BERTABS_FINETUNED_CONFIG_MAP = {
"bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/config.json",
}
class BertAbsConfig(PretrainedConfig):
r""" Class to store the configuration of the BertAbs model.
Arguments:
vocab_size: int
Number of tokens in the vocabulary.
max_pos: int
The maximum sequence length that this model will be used with.
enc_layer: int
The numner of hidden layers in the Transformer encoder.
enc_hidden_size: int
The size of the encoder's layers.
enc_heads: int
The number of attention heads for each attention layer in the encoder.
enc_ff_size: int
The size of the encoder's feed-forward layers.
enc_dropout: int
The dropout probabilitiy for all fully connected layers in the
embeddings, layers, pooler and also the attention probabilities in
the encoder.
dec_layer: int
The numner of hidden layers in the decoder.
dec_hidden_size: int
The size of the decoder's layers.
dec_heads: int
The number of attention heads for each attention layer in the decoder.
dec_ff_size: int
The size of the decoder's feed-forward layers.
dec_dropout: int
The dropout probabilitiy for all fully connected layers in the
embeddings, layers, pooler and also the attention probabilities in
the decoder.
"""
model_type = "bertabs"
def __init__(
self,
vocab_size=30522,
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
**kwargs,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.max_pos = max_pos
self.enc_layers = enc_layers
self.enc_hidden_size = enc_hidden_size
self.enc_heads = enc_heads
self.enc_ff_size = enc_ff_size
self.enc_dropout = enc_dropout
self.dec_layers = dec_layers
self.dec_hidden_size = dec_hidden_size
self.dec_heads = dec_heads
self.dec_ff_size = dec_ff_size
self.dec_dropout = dec_dropout

View File

@@ -0,0 +1,176 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints.
The script looks like it is doing something trivial but it is not. The "weights"
proposed by the authors are actually the entire model pickled. We need to load
the model within the original codebase to be able to only save its `state_dict`.
"""
import argparse
import logging
from collections import namedtuple
import torch
from model_bertabs import BertAbsSummarizer
from models.model_builder import AbsSummarizer # The authors' implementation
from transformers import BertTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = "Hello world! cécé herlolip"
BertAbsConfig = namedtuple(
"BertAbsConfig",
[
"temp_dir",
"large",
"use_bert_emb",
"finetune_bert",
"encoder",
"share_emb",
"max_pos",
"enc_layers",
"enc_hidden_size",
"enc_heads",
"enc_ff_size",
"enc_dropout",
"dec_layers",
"dec_hidden_size",
"dec_heads",
"dec_ff_size",
"dec_dropout",
],
)
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertAbs for the internal architecture.
"""
# Instantiate the authors' model with the pre-trained weights
config = BertAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
use_bert_emb=False,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
original.eval()
new_model = BertAbsSummarizer(config, torch.device("cpu"))
new_model.eval()
# -------------------
# Convert the weights
# -------------------
logging.info("convert the model")
new_model.bert.load_state_dict(original.bert.state_dict())
new_model.decoder.load_state_dict(original.decoder.state_dict())
new_model.generator.load_state_dict(original.generator.state_dict())
# ----------------------------------
# Make sure the outpus are identical
# ----------------------------------
logging.info("Make sure that the models' outputs are identical")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# prepare the model inputs
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
# failsafe to make sure the weights reset does not affect the
# loaded weights.
assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0
# forward pass
src = encoder_input_ids
tgt = decoder_input_ids
segs = token_type_ids = None
clss = None
mask_src = encoder_attention_mask = None
mask_tgt = decoder_attention_mask = None
mask_cls = None
# The original model does not apply the geneator layer immediatly but rather in
# the beam search (where it combines softmax + linear layer). Since we already
# apply the softmax in our generation process we only apply the linear layer here.
# We make sure that the outputs of the full stack are identical
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
output_original_generator = original.generator(output_original_model)
output_converted_model = new_model(
encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
)[0]
output_converted_generator = new_model.generator(output_converted_model)
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
if are_identical:
logging.info("all weights are equal up to 1e-3")
else:
raise ValueError("the weights are different. The new model is likely different from the original one.")
# The model has been saved with torch.save(model) and this is bound to the exact
# directory structure. We save the state_dict instead.
logging.info("saving the model's state dictionary")
torch.save(
new_model.state_dict(), "./bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin"
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertabs_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertabs_checkpoints(
args.bertabs_checkpoint_path, args.pytorch_dump_folder_path,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,5 @@
transformers
# For ROUGE
nltk
py-rouge

View File

@@ -0,0 +1,324 @@
#! /usr/bin/python3
import argparse
import logging
import os
import sys
from collections import namedtuple
import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from modeling_bertabs import BertAbs, build_predictor
from transformers import BertTokenizer
from .utils_summarization import (
CNNDMDataset,
build_mask,
compute_token_type_ids,
encode_for_summarization,
truncate_or_pad,
)
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])
def evaluate(args):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
model.to(args.device)
model.eval()
symbols = {
"BOS": tokenizer.vocab["[unused0]"],
"EOS": tokenizer.vocab["[unused1]"],
"PAD": tokenizer.vocab["[PAD]"],
}
if args.compute_rouge:
reference_summaries = []
generated_summaries = []
import rouge
import nltk
nltk.download("punkt")
rouge_evaluator = rouge.Rouge(
metrics=["rouge-n", "rouge-l"],
max_n=2,
limit_length=True,
length_limit=args.beam_size,
length_limit_type="words",
apply_avg=True,
apply_best=False,
alpha=0.5, # Default F1_score
weight_factor=1.2,
stemming=True,
)
# these (unused) arguments are defined to keep the compatibility
# with the legacy code and will be deleted in a next iteration.
args.result_path = ""
args.temp_dir = ""
data_iterator = build_data_iterator(args, tokenizer)
predictor = build_predictor(args, tokenizer, symbols, model)
logger.info("***** Running evaluation *****")
logger.info(" Number examples = %d", len(data_iterator.dataset))
logger.info(" Batch size = %d", args.batch_size)
logger.info("")
logger.info("***** Beam Search parameters *****")
logger.info(" Beam size = %d", args.beam_size)
logger.info(" Minimum length = %d", args.min_length)
logger.info(" Maximum length = %d", args.max_length)
logger.info(" Alpha (length penalty) = %.2f", args.alpha)
logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))
for batch in tqdm(data_iterator):
batch_data = predictor.translate_batch(batch)
translations = predictor.from_batch(batch_data)
summaries = [format_summary(t) for t in translations]
save_summaries(summaries, args.summaries_output_dir, batch.document_names)
if args.compute_rouge:
reference_summaries += batch.tgt_str
generated_summaries += summaries
if args.compute_rouge:
scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
str_scores = format_rouge_scores(scores)
save_rouge_scores(str_scores)
print(str_scores)
def save_summaries(summaries, path, original_document_name):
""" Write the summaries in fies that are prefixed by the original
files' name with the `_summary` appended.
Attributes:
original_document_names: List[string]
Name of the document that was summarized.
path: string
Path were the summaries will be written
summaries: List[string]
The summaries that we produced.
"""
for summary, document_name in zip(summaries, original_document_name):
# Prepare the summary file's name
if "." in document_name:
bare_document_name = ".".join(document_name.split(".")[:-1])
extension = document_name.split(".")[-1]
name = bare_document_name + "_summary." + extension
else:
name = document_name + "_summary"
file_path = os.path.join(path, name)
with open(file_path, "w") as output:
output.write(summary)
def format_summary(translation):
""" Transforms the output of the `from_batch` function
into nicely formatted summaries.
"""
raw_summary, _, _ = translation
summary = (
raw_summary.replace("[unused0]", "")
.replace("[unused3]", "")
.replace("[PAD]", "")
.replace("[unused1]", "")
.replace(r" +", " ")
.replace(" [unused2] ", ". ")
.replace("[unused2]", "")
.strip()
)
return summary
def format_rouge_scores(scores):
return """\n
****** ROUGE SCORES ******
** ROUGE 1
F1 >> {:.3f}
Precision >> {:.3f}
Recall >> {:.3f}
** ROUGE 2
F1 >> {:.3f}
Precision >> {:.3f}
Recall >> {:.3f}
** ROUGE L
F1 >> {:.3f}
Precision >> {:.3f}
Recall >> {:.3f}""".format(
scores["rouge-1"]["f"],
scores["rouge-1"]["p"],
scores["rouge-1"]["r"],
scores["rouge-2"]["f"],
scores["rouge-2"]["p"],
scores["rouge-2"]["r"],
scores["rouge-l"]["f"],
scores["rouge-l"]["p"],
scores["rouge-l"]["r"],
)
def save_rouge_scores(str_scores):
with open("rouge_scores.txt", "w") as output:
output.write(str_scores)
#
# LOAD the dataset
#
def build_data_iterator(args, tokenizer):
dataset = load_and_cache_examples(args, tokenizer)
sampler = SequentialSampler(dataset)
def collate_fn(data):
return collate(data, tokenizer, block_size=512, device=args.device)
iterator = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,)
return iterator
def load_and_cache_examples(args, tokenizer):
dataset = CNNDMDataset(args.documents_dir)
return dataset
def collate(data, tokenizer, block_size, device):
""" Collate formats the data passed to the data loader.
In particular we tokenize the data batch after batch to avoid keeping them
all in memory. We output the data as a namedtuple to fit the original BertAbs's
API.
"""
data = [x for x in data if not len(x[1]) == 0] # remove empty_files
names = [name for name, _, _ in data]
summaries = [" ".join(summary_list) for _, _, summary_list in data]
encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
encoded_stories = torch.tensor(
[truncate_or_pad(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
)
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
batch = Batch(
document_names=names,
batch_size=len(encoded_stories),
src=encoded_stories.to(device),
segs=encoder_token_type_ids.to(device),
mask_src=encoder_mask.to(device),
tgt_str=summaries,
)
return batch
def decode_summary(summary_tokens, tokenizer):
""" Decode the summary and return it in a format
suitable for evaluation.
"""
summary_tokens = summary_tokens.to("cpu").numpy()
summary = tokenizer.decode(summary_tokens)
sentences = summary.split(".")
sentences = [s + "." for s in sentences]
return sentences
def main():
""" The main function defines the interface with the users.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--documents_dir",
default=None,
type=str,
required=True,
help="The folder where the documents to summarize are located.",
)
parser.add_argument(
"--summaries_output_dir",
default=None,
type=str,
required=False,
help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
)
parser.add_argument(
"--compute_rouge",
default=False,
type=bool,
required=False,
help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.",
)
# EVALUATION options
parser.add_argument(
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
)
parser.add_argument(
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
)
# BEAM SEARCH arguments
parser.add_argument(
"--min_length", default=50, type=int, help="Minimum number of tokens for the summaries.",
)
parser.add_argument(
"--max_length", default=200, type=int, help="Maixmum number of tokens for the summaries.",
)
parser.add_argument(
"--beam_size", default=5, type=int, help="The number of beams to start with for each example.",
)
parser.add_argument(
"--alpha", default=0.95, type=float, help="The value of alpha for the length penalty in the beam search.",
)
parser.add_argument(
"--block_trigram",
default=True,
type=bool,
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
)
args = parser.parse_args()
# Select device (distibuted not available)
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
# Check the existence of directories
if not args.summaries_output_dir:
args.summaries_output_dir = args.documents_dir
if not documents_dir_is_valid(args.documents_dir):
raise FileNotFoundError(
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
)
os.makedirs(args.summaries_output_dir, exist_ok=True)
evaluate(args)
def documents_dir_is_valid(path):
if not os.path.exists(path):
return False
file_list = os.listdir(path)
if len(file_list) == 0:
return False
return True
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,100 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from .utils_summarization import build_mask, compute_token_type_ids, process_story, truncate_or_pad
class SummarizationDataProcessingTest(unittest.TestCase):
def setUp(self):
self.block_size = 10
def test_fit_to_block_sequence_too_small(self):
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
sequence = [1, 2, 3, 4]
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
def test_fit_to_block_sequence_fit_exactly(self):
""" Do nothing if the sequence is the right size. """
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
def test_fit_to_block_sequence_too_big(self):
""" Truncate the sequence if it is too long. """
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
def test_process_story_no_highlights(self):
""" Processing a story with no highlights returns an empty list for the summary.
"""
raw_story = """It was the year of Our Lord one thousand seven hundred and
seventy-five.\n\nSpiritual revelations were conceded to England at that
favoured period, as at this."""
_, summary_lines = process_story(raw_story)
self.assertEqual(summary_lines, [])
def test_process_empty_story(self):
""" An empty story returns an empty collection of lines.
"""
raw_story = ""
story_lines, summary_lines = process_story(raw_story)
self.assertEqual(story_lines, [])
self.assertEqual(summary_lines, [])
def test_process_story_with_missing_period(self):
raw_story = (
"It was the year of Our Lord one thousand seven hundred and "
"seventy-five\n\nSpiritual revelations were conceded to England "
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
)
story_lines, summary_lines = process_story(raw_story)
expected_story_lines = [
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
"Spiritual revelations were conceded to England at that favoured period, as at this.",
]
self.assertEqual(expected_story_lines, story_lines)
expected_summary_lines = ["It was the best of times."]
self.assertEqual(expected_summary_lines, summary_lines)
def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1])
np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
def test_build_mask(self):
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal(build_mask(sequence, 23).numpy(), expected.numpy())
def test_build_mask_with_padding_equal_to_one(self):
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
def test_compute_token_type_ids(self):
separator = 101
batch = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]])
expected = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]])
result = compute_token_type_ids(batch, separator)
np.testing.assert_array_equal(result, expected)

View File

@@ -0,0 +1,167 @@
import os
from collections import deque
import torch
from torch.utils.data import Dataset
# ------------
# Data loading
# ------------
class CNNDMDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
The class will process the documents that are located in the specified
folder. The preprocessing will work on any document that is reasonably
formatted. On the CNN/DailyMail dataset it will extract both the story
and the summary.
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
stored in different files; the summary appears at the end of the story as
sentences that are prefixed by the special `@highlight` line. To process
the data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
[2] https://github.com/abisee/cnn-dailymail/
"""
def __init__(self, path="", prefix="train"):
""" We initialize the class by listing all the documents to summarize.
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
"""
assert os.path.isdir(path)
self.documents = []
story_filenames_list = os.listdir(path)
for story_filename in story_filenames_list:
if "summary" in story_filename:
continue
path_to_story = os.path.join(path, story_filename)
if not os.path.isfile(path_to_story):
continue
self.documents.append(path_to_story)
def __len__(self):
""" Returns the number of documents. """
return len(self.documents)
def __getitem__(self, idx):
document_path = self.documents[idx]
document_name = document_path.split("/")[-1]
with open(document_path, encoding="utf-8") as source:
raw_story = source.read()
story_lines, summary_lines = process_story(raw_story)
return document_name, story_lines, summary_lines
def process_story(raw_story):
""" Extract the story and summary from a story file.
Arguments:
raw_story (str): content of the story file as an utf-8 encoded string.
Raises:
IndexError: If the story is empty or contains no highlights.
"""
nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
# for some unknown reason some lines miss a period, add it
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
# gather article lines
story_lines = []
lines = deque(nonempty_lines)
while True:
try:
element = lines.popleft()
if element.startswith("@highlight"):
break
story_lines.append(element)
except IndexError:
# if "@highlight" is absent from the file we pop
# all elements until there is None, raising an exception.
return story_lines, []
# gather summary lines
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
return story_lines, summary_lines
def _add_missing_period(line):
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u2019", ")"]
if line.startswith("@highlight"):
return line
if line[-1] in END_TOKENS:
return line
return line + "."
# --------------------------
# Encoding and preprocessing
# --------------------------
def truncate_or_pad(sequence, block_size, pad_token_id):
""" Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter we append padding token to the right of the sequence.
"""
if len(sequence) > block_size:
return sequence[:block_size]
else:
sequence.extend([pad_token_id] * (block_size - len(sequence)))
return sequence
def build_mask(sequence, pad_token_id):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
mask = torch.ones_like(sequence)
idx_pad_tokens = sequence == pad_token_id
mask[idx_pad_tokens] = 0
return mask
def encode_for_summarization(story_lines, summary_lines, tokenizer):
""" Encode the story and summary lines, and join them
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]
return story_token_ids, summary_token_ids
def compute_token_type_ids(batch, separator_token_id):
""" Segment embeddings as described in [1]
The values {0,1} were found in the repository [2].
Attributes:
batch: torch.Tensor, size [batch_size, block_size]
Batch of input.
separator_token_id: int
The value of the token that separates the segments.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
"""
batch_embeddings = []
for sequence in batch:
sentence_num = -1
embeddings = []
for s in sequence:
if s == separator_token_id:
sentence_num += 1
embeddings.append(sentence_num % 2)
batch_embeddings.append(embeddings)
return torch.tensor(batch_embeddings)

View File

@@ -0,0 +1,92 @@
import logging
import os
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
def count_trainable_parameters(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return params
logger = logging.getLogger(__name__)
class Seq2SeqLoggingCallback(pl.Callback):
@rank_zero_only
def _write_logs(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
) -> None:
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
metrics = trainer.callback_metrics
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
# Log results
od = Path(pl_module.hparams.output_dir)
if type_path == "test":
results_file = od / "test_results.txt"
generations_file = od / "test_generations.txt"
else:
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
# If people want this it will be easy enough to add back.
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
results_file.parent.mkdir(exist_ok=True)
generations_file.parent.mkdir(exist_ok=True)
with open(results_file, "a+") as writer:
for key in sorted(metrics):
if key in ["log", "progress_bar", "preds"]:
continue
val = metrics[key]
if isinstance(val, torch.Tensor):
val = val.item()
msg = f"{key}: {val:.6f}\n"
writer.write(msg)
if not save_generations:
return
if "preds" in metrics:
content = "\n".join(metrics["preds"])
generations_file.open("w+").write(content)
@rank_zero_only
def on_train_start(self, trainer, pl_module):
try:
npars = pl_module.model.model.num_parameters()
except AttributeError:
npars = pl_module.model.num_parameters()
n_trainable_pars = count_trainable_parameters(pl_module)
# mp stands for million parameters
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
return self._write_logs(trainer, pl_module, "test")
def get_checkpoint_callback(output_dir, metric):
"""Saves the best model by validation ROUGE2 score."""
if metric == "rouge2":
exp = "{val_avg_rouge2:.4f}-{step_count}"
elif metric == "bleu":
exp = "{val_avg_bleu:.4f}-{step_count}"
else:
raise NotImplementedError(
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(output_dir, exp),
monitor=f"val_{metric}",
mode="max",
save_top_k=1,
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
)
return checkpoint_callback

View File

@@ -0,0 +1,454 @@
import argparse
import gc
import os
from pathlib import Path
from typing import List
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
from lightning_base import generic_train
from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration
try:
from .finetune import SummarizationModule
from .initialization_utils import init_student, copy_layers
from .utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
from .finetune import main as ft_main
except ImportError:
from finetune import SummarizationModule
from finetune import main as ft_main
from initialization_utils import init_student, copy_layers
from utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
class BartSummarizationDistiller(SummarizationModule):
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
def __init__(self, hparams):
assert Path(hparams.data_dir).exists()
student, student_cfg, teacher = self.pre_init(hparams)
super().__init__(hparams, model=student, config=student_cfg)
self.teacher = teacher
use_task_specific_params(self.teacher, "summarization")
freeze_params(self.teacher)
self.sanity_check_gradients()
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.temperature = 2.0
self.alpha_mlm = hparams.alpha_mlm
self.alpha_ce = hparams.alpha_ce
self.alpha_hid = hparams.alpha_hid
# self.alpha_cos = hparams.alpha_cos
self.alpha_encoder_loss = self.hparams.alpha_encoder_loss
gc.collect()
torch.cuda.empty_cache()
def sanity_check_gradients(self):
assert_all_frozen(self.teacher)
assert_all_frozen(self.model.model.decoder.embed_tokens)
assert_all_frozen(self.model.model.encoder.embed_tokens)
if self.different_encoder:
assert any_requires_grad(self.model.model.encoder)
else:
freeze_params(self.model.model.encoder)
del self.teacher.model.encoder
def pre_init(self, hparams):
self.output_dir = Path(hparams.output_dir)
self.output_dir.mkdir(exist_ok=True)
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
student_updates = {
"decoder_layers": hparams.student_decoder_layers,
"encoder_layers": hparams.student_encoder_layers,
}
if hparams.length_penalty != -1:
student_updates["length_penalty"] = hparams.length_penalty
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
hparams.d_layer_to_copy = d_layers_to_copy
hparams.e_layer_to_copy = e_layers_to_copy
kw = teacher.config.to_diff_dict()
kw.update(student_updates)
# Copy weights
student_cfg = BartConfig(**kw)
student = BartForConditionalGeneration(student_cfg)
student, _ = init_student(student, teacher)
save_dir = self.output_dir.joinpath("student")
save_dir.mkdir(exist_ok=True)
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
student.save_pretrained(save_dir)
hparams.model_name_or_path = str(save_dir)
return student, student_cfg, teacher
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
if teacher.config.model_type == "t5":
return self.copy_t5_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
if self.different_decoder:
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
if self.different_encoder:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.num_layers
self.different_decoder = hparams.student_decoder_layers != teacher.config.num_layers
if self.different_decoder:
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
if self.different_encoder:
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
def get_dataset(self, type_path) -> SummarizationDataset:
n_obs = self.n_obs[type_path]
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
return dataset
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
if mask is not None:
# mask has False at padding_idx
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
s_logits_slct = torch.masked_select(student_outputs, sel_mask)
t_logits_slct = torch.masked_select(teacher_outputs, sel_mask)
else:
t_logits_slct = teacher_outputs
s_logits_slct = student_outputs
return F.mse_loss(s_logits_slct, t_logits_slct)
def calc_ce_loss(self, mask, s_logits, t_logits):
if mask is not None:
# mask has False at padding_idx
sel_mask = mask[:, :, None].expand_as(s_logits)
s_logits_slct = torch.masked_select(
s_logits, sel_mask
) # (bs * seq_length * voc_size) modulo the 1s in mask
t_logits_slct = torch.masked_select(
t_logits, sel_mask
) # (bs * seq_length * voc_size) modulo the 1s in mask
else:
t_logits_slct = t_logits
s_logits_slct = s_logits # (bs * seq_length * voc_size) modulo the 1s in mask
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
assert t_logits_slct.size() == s_logits_slct.size()
loss_ce = (
self.ce_loss_fct(
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
F.softmax(t_logits_slct / self.temperature, dim=-1),
)
* (self.temperature) ** 2
)
return loss_ce, s_logits_slct, t_logits_slct
def configure_optimizers(self):
"Prepare optimizer and schedule (linear warmup and decay)"
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": self.hparams.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
self.opt = optimizer
return [optimizer]
@staticmethod
def add_model_specific_args(parser, root_dir):
SummarizationModule.add_model_specific_args(parser, root_dir)
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float)
# parser.add_argument("--alpha_cos", default=0.0, type=float)
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
parser.add_argument("--no_teacher", action="store_true", default=False)
parser.add_argument("--length_penalty", type=float, default=-1)
return parser
def _step(self, batch):
# assert is_frozen(self.teacher)
pad_token_id = self.tokenizer.pad_token_id
input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = y[:, :-1].contiguous()
labels = y[:, 1:].clone()
labels[y[:, 1:] == pad_token_id] = -100
# noinspection PyCallingNonCallable
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
input_ids,
attention_mask=src_mask,
decoder_input_ids=decoder_input_ids,
labels=labels,
output_hidden_states=True,
output_attentions=False,
)
def zero_tensor():
return torch.tensor(0.0).type_as(sloss)
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
if self.different_encoder:
with torch.no_grad():
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
input_ids, attention_mask=src_mask, output_hidden_states=True
)
if self.hparams.alpha_encoder_loss > 0:
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
hid_loss_enc = self.calc_hidden_loss(
src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
)
teacher_enc_outputs = (enc_outputs,)
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
with torch.no_grad():
tloss, tlogits, tdec_hidden, _ = self.teacher(
input_ids,
attention_mask=src_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
output_hidden_states=True,
)
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
if self.alpha_hid > 0:
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
blended_loss = (
self.alpha_ce * loss_ce
+ self.alpha_mlm * sloss
+ self.hparams.alpha_encoder_loss * loss_encoder
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
)
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
assert not isinstance(
hidden_states, torch.Tensor
), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
assert not isinstance(
hidden_states_T, torch.Tensor
), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
mask = attention_mask.to(hidden_states[0])
valid_count = mask.sum() * hidden_states[0].size(-1)
hidden_losses = [
(F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum()
/ valid_count
for i, j in enumerate(matches)
]
return sum(hidden_losses)
class T5SummarizationDistiller(BartSummarizationDistiller):
def pre_init(self, hparams):
raise NotImplementedError("T5 Distillation does not work yet")
self.output_dir = Path(hparams.output_dir)
self.output_dir.mkdir(exist_ok=True)
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
n_layer = hparams.student_decoder_layers
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this constraint so that we can do 12-6.
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
student_updates = {"num_layers": n_layer}
hparams.d_layer_to_copy = d_layers_to_copy
hparams.e_layer_to_copy = e_layers_to_copy
kw = teacher.config.to_diff_dict()
kw.update(student_updates)
# Copy weights
student_cfg = T5Config(**kw)
student = T5ForConditionalGeneration(student_cfg)
student, _ = init_student(student, teacher)
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
Path(hparams.output_dir).mkdir(exist_ok=True)
task_specific_params = student.config.task_specific_params
if task_specific_params is not None:
student.config.update(task_specific_params.get("summarization", {})) # TODO: dont hardcode
save_dir = self.output_dir.joinpath("student")
save_dir.mkdir(exist_ok=True)
student.save_pretrained(save_dir)
hparams.model_name_or_path = str(save_dir)
return student, student_cfg, teacher
def freeze_embeds(self):
freeze_params(self.model.shared)
for d in [self.model.encoder, self.model.decoder]:
freeze_params(d.embed_tokens)
def sanity_check_gradients(self):
"""T5"""
assert_all_frozen(self.teacher)
assert_all_frozen(self.model.decoder.embed_tokens)
assert_all_frozen(self.model.encoder.embed_tokens)
if self.different_encoder:
assert any_requires_grad(self.model.encoder)
else:
freeze_params(self.model.encoder)
del self.teacher.model.encoder
if self.different_decoder:
assert any_requires_grad(self.model.decoder)
else:
freeze_params(self.model.decoder) # TODO(SS): very suspicious
def _step(self, batch):
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = y[:, :-1].contiguous()
labels = y[:, 1:].clone()
labels[y[:, 1:] == pad_token_id] = -100
# noinspection PyCallingNonCallable
dec_mask = decoder_input_ids.ne(pad_token_id)
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
source_ids,
attention_mask=source_mask,
decoder_input_ids=decoder_input_ids,
labels=labels,
output_hidden_states=True,
output_attentions=False,
use_cache=False,
)
def zero_tensor():
return torch.tensor(0.0).type_as(sloss)
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
if self.different_encoder:
with torch.no_grad():
teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder(
source_ids, attention_mask=source_mask, output_hidden_states=True, use_cache=False,
)
if self.hparams.alpha_encoder_loss > 0:
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask)
hid_loss_enc = self.calc_hidden_loss(
source_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
)
teacher_enc_outputs = (enc_outputs,)
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
with torch.no_grad():
tloss, tlogits, tdec_hidden, _ = self.teacher(
source_ids,
attention_mask=source_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
output_hidden_states=True,
use_cache=False,
)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
if self.alpha_hid > 0:
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
blended_loss = (
self.alpha_ce * loss_ce
+ self.alpha_mlm * sloss
+ self.hparams.alpha_encoder_loss * loss_encoder
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
)
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
def create_module(args):
t5 = "t5" in args.model_name_or_path
if args.no_teacher:
assert not args.enc_only
module_cls = SummarizationModule
elif t5:
module_cls = T5SummarizationDistiller
elif args.enc_only:
raise ValueError("Deleted that")
else:
module_cls = BartSummarizationDistiller
args.setup_cls: str = module_cls.__name__
model = module_cls(args)
return model
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
exp_dir = ckpt_path.parent
if dest_dir is None:
dest_dir = exp_dir
clash = list(dest_dir.glob("test_generations*"))
if clash:
print(f"SKIPPING to avoid overwriting {clash}")
ckpt = torch.load(ckpt_path, map_location="cpu")
if "hparams" in ckpt:
args = argparse.Namespace(**ckpt["hparams"])
else:
args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl"))
args.resume_from_checkpoint = str(ckpt_path)
args.do_train = False
args.output_dir = str(dest_dir)
args.n_gpu = 1
args.eval_batch_size = 16
Path(args.output_dir).mkdir(exist_ok=True)
model = create_module(args)
trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False)
trainer.test(model)
def get_layers_to_copy(n_to_get, tot):
all_layers = list(range(tot))
if tot == 12: # Alternating for special cases
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 6],
3: [0, 6, 11],
4: [0, 4, 8, 11],
6: [0, 2, 4, 7, 9, 11],
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
12: all_layers,
}
return layers_to_copy[n_to_get]
else:
return all_layers[:n_to_get] # TODO: better version on theseus-bart branch
def distill_main(args):
Path(args.output_dir).mkdir(exist_ok=True)
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
model = create_module(args)
return ft_main(args, model=model)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
distill_main(args)

View File

@@ -0,0 +1,343 @@
import argparse
import glob
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import get_linear_schedule_with_warmup
try:
from .utils import (
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
save_git_info,
save_json,
freeze_params,
calculate_rouge,
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
)
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
except ImportError:
from utils import (
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
save_git_info,
save_json,
freeze_params,
calculate_rouge,
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
logger = logging.getLogger(__name__)
class SummarizationModule(BaseTransformer):
mode = "summarization"
loss_names = ["loss"]
metric_names = ROUGE_KEYS
val_metric = "rouge2"
def __init__(self, hparams, **kwargs):
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
use_task_specific_params(self.model, "summarization")
save_git_info(self.hparams.output_dir)
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
pickle_save(self.hparams, self.hparams_save_path)
self.step_count = 0
self.metrics = defaultdict(list)
self.dataset_kwargs: dict = dict(
data_dir=self.hparams.data_dir,
max_source_length=self.hparams.max_source_length,
prefix=self.model.config.prefix or "",
)
n_observations_per_split = {
"train": self.hparams.n_train,
"val": self.hparams.n_val,
"test": self.hparams.n_test,
}
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
self.target_lens = {
"train": self.hparams.max_target_length,
"val": self.hparams.val_max_target_length,
"test": self.hparams.test_max_target_length,
}
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
if self.hparams.freeze_embeds:
self.freeze_embeds()
if self.hparams.freeze_encoder:
freeze_params(self.model.model.encoder) # TODO: this will break for t5
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
try:
freeze_params(self.model.model.shared)
for d in [self.model.model.encoder, self.model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
except AttributeError:
freeze_params(self.model.shared)
for d in [self.model.encoder, self.model.decoder]:
freeze_params(d.embed_tokens)
def forward(self, input_ids, **kwargs):
return self.model(input_ids, **kwargs)
def ids_to_clean_text(self, generated_ids: List[int]):
gen_text = self.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
return lmap(str.strip, gen_text)
def _step(self, batch: dict) -> Tuple:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
y_ids = y[:, :-1].contiguous()
lm_labels = y[:, 1:].clone()
lm_labels[y[:, 1:] == pad_token_id] = -100
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
loss = outputs[0]
return (loss,)
def training_step(self, batch, batch_idx) -> Dict:
loss_tensors = self._step(batch)
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
return {"loss": loss_tensors[0], "log": logs}
def validation_step(self, batch, batch_idx) -> Dict:
return self._generative_step(batch)
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
rouges.update({k: v.item() for k, v in losses.items()})
losses.update(rouges)
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
metrics["step_count"] = self.step_count
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
preds = flatten_list([x["preds"] for x in outputs])
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
def save_metrics(self, latest_metrics, type_path) -> None:
self.metrics[type_path].append(latest_metrics)
save_json(self.metrics, self.metrics_save_path)
def calc_generative_metrics(self, preds, target) -> Dict:
return calculate_rouge(preds, target)
def _generative_step(self, batch: dict) -> dict:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
t0 = time.time()
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
gen_time = (time.time() - t0) / source_ids.shape[0]
preds = self.ids_to_clean_text(generated_ids)
target = self.ids_to_clean_text(y)
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
rouge: Dict = self.calc_generative_metrics(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
return base_metrics
def test_step(self, batch, batch_idx):
return self._generative_step(batch)
def test_epoch_end(self, outputs):
return self.validation_epoch_end(outputs, prefix="test")
def get_dataset(self, type_path) -> SummarizationDataset:
n_obs = self.n_obs[type_path]
max_target_length = self.target_lens[type_path]
dataset = SummarizationDataset(
self.tokenizer,
type_path=type_path,
n_obs=n_obs,
max_target_length=max_target_length,
**self.dataset_kwargs,
)
return dataset
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
dataset = self.get_dataset(type_path)
sampler = None
if self.hparams.sortish_sampler and type_path == "train":
assert self.hparams.gpus <= 1 # TODO: assert earlier
sampler = dataset.make_sortish_sampler(batch_size)
shuffle = False
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate_fn,
shuffle=shuffle,
num_workers=self.num_workers,
sampler=sampler,
)
return dataloader
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.gradient_accumulation_steps
* float(self.hparams.num_train_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
)
self.lr_scheduler = scheduler
return dataloader
def val_dataloader(self) -> DataLoader:
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
def test_dataloader(self) -> DataLoader:
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
@staticmethod
def add_model_specific_args(parser, root_dir):
BaseTransformer.add_model_specific_args(parser, root_dir)
add_generic_args(parser, root_dir)
parser.add_argument(
"--max_source_length",
default=1024,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--max_target_length",
default=56,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--val_max_target_length",
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--test_max_target_length",
default=142,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--data_dir",
type=str,
required=True,
help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
)
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
)
return parser
class TranslationModule(SummarizationModule):
mode = "translation"
loss_names = ["loss"]
metric_names = ["bleu"]
val_metric = "bleu"
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu_score(preds, target)
def main(args, model=None) -> SummarizationModule:
Path(args.output_dir).mkdir(exist_ok=True)
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if model is None:
if args.task == "summarization":
model: SummarizationModule = SummarizationModule(args)
else:
model: SummarizationModule = TranslationModule(args)
dataset = Path(args.data_dir).name
if (
args.logger == "default"
or args.fast_dev_run
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
logger = True # don't pollute wandb logs unnecessarily
elif args.logger == "wandb":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=dataset)
elif args.logger == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
logger=logger,
# TODO: early stopping callback seems messed up
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
if not args.do_predict:
return model
model.hparams.test_checkpoint = ""
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
if checkpoints:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1]
trainer.logger.log_hyperparams(model.hparams)
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)

14
examples/seq2seq/finetune.sh Executable file
View File

@@ -0,0 +1,14 @@
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--n_val 1000 \
--val_check_interval 0.1 \
--sortish_sampler \
$@

View File

@@ -0,0 +1,32 @@
# Script for verifying that run_bart_sum can be invoked from its directory
# Get tiny dataset with cnn_dm format (4 examples for train, val, test)
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_tiny.tgz
tar -xzvf cnn_tiny.tgz
rm cnn_tiny.tgz
export OUTPUT_DIR_NAME=bart_utest_output
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py and utils.py
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--data_dir=cnn_tiny/ \
--model_name_or_path=sshleifer/bart-tiny-random \
--learning_rate=3e-5 \
--train_batch_size=2 \
--eval_batch_size=2 \
--output_dir=$OUTPUT_DIR \
--num_train_epochs=1 \
--gpus=0 \
--do_train $@
rm -rf cnn_tiny
rm -rf $OUTPUT_DIR

18
examples/seq2seq/finetune_t5.sh Executable file
View File

@@ -0,0 +1,18 @@
export OUTPUT_DIR_NAME=t5
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_name_or_path=t5-large \
--learning_rate=3e-5 \
--train_batch_size=4 \
--eval_batch_size=4 \
--output_dir=$OUTPUT_DIR \
--do_train $@

View File

@@ -0,0 +1,20 @@
from typing import List
from torch import nn
def init_student(student, teacher):
teacher_state_dict = teacher.state_dict()
info = student.load_state_dict(teacher_state_dict, strict=False)
assert info.missing_keys == [], info.missing_keys
return student, info
def copy_decoder_layers(teacher, student, l2copy=[0, 2, 4, 7, 9, 11]):
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, l2copy)
def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, layers_to_copy: List) -> None:
layers_to_copy = nn.ModuleList([l for i, l in enumerate(teacher_layers) if i in layers_to_copy])
assert len(student_layers) == len(layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}"
student_layers.load_state_dict(layers_to_copy.state_dict())

View File

@@ -0,0 +1,10 @@
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.1 \
$@

View File

@@ -0,0 +1,88 @@
import argparse
import json
from pathlib import Path
import torch
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
try:
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
except ImportError:
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
def generate_summaries_or_translations(
examples: list,
out_file: str,
model_name: str,
batch_size: int = 8,
device: str = DEFAULT_DEVICE,
fp16=False,
**gen_kwargs,
) -> None:
fout = Path(out_file).open("w", encoding="utf-8")
model_name = str(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
if fp16:
model = model.half()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# update config with summarization specific params
use_task_specific_params(model, "summarization")
for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
batch = tokenizer.batch_encode_plus(
batch, max_length=1024, return_tensors="pt", truncation=True, pad_to_max_length=True
).to(device)
summaries = model.generate(**batch, **gen_kwargs)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()
def run_generate():
parser = argparse.ArgumentParser()
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("output_path", type=str, help="where to save summaries")
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
generate_summaries_or_translations(
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
)
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
scores = {}
if args.reference_path is not None:
score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
scores: dict = score_fn(output_lns, reference_lns)
if args.score_path is not None:
json.dump(scores, open("score_path", "w+"))
return scores
if __name__ == "__main__":
run_generate()

View File

@@ -0,0 +1,252 @@
import argparse
import logging
import os
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import pytest
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import SummarizationDataset, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"logger": "default",
"length_penalty": 0.5,
"cache_dir": "",
"task": "summarization",
"num_workers": 2,
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
"tgt_suffix": "",
"resume_from_checkpoint": None,
"sortish_sampler": True,
"student_decoder_layers": 1,
"val_check_interval": 1.0,
"output_dir": "",
"fp16": CUDA_AVAILABLE,
"no_teacher": False,
"fp16_opt_level": "O1",
"gpus": 1 if CUDA_AVAILABLE else 0,
"n_tpu_cores": 0,
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"gradient_accumulation_steps": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
"model_name_or_path": "sshleifer/bart-tiny-random",
"config_name": "",
"tokenizer_name": "facebook/bart-large",
"do_lower_case": False,
"learning_rate": 0.3,
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"num_train_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
"max_target_length": 12,
"val_max_target_length": 12,
"test_max_target_length": 12,
"fast_dev_run": False,
"no_cache": False,
"n_train": -1,
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"alpha_loss_encoder": 0.0,
"freeze_encoder": False,
"auto_scale_batch_size": False,
}
def _dump_articles(path: Path, articles: list):
with path.open("w") as f:
f.write("\n".join(articles))
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
def make_test_data_dir(**kwargs):
tmp_dir = Path(tempfile.mkdtemp(**kwargs))
for split in ["train", "val", "test"]:
_dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
_dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
return tmp_dir
class TestSummarizationDistiller(unittest.TestCase):
@classmethod
def setUpClass(cls):
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
def test_multigpu(self):
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
self._test_distiller_cli(updates)
def test_distill_no_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
self._test_distiller_cli(updates)
def test_distill_checkpointing_with_teacher(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
)
model = self._test_distiller_cli(updates, check_contents=False)
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
self.assertEqual(len(transformer_ckpts), 2)
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
out_path = tempfile.mktemp()
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
self.assertTrue(Path(out_path).exists())
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
@unittest.skip("T5 distillation is broken at the moment")
def test_distill_t5(self):
updates = dict(
student_encoder_layers=1,
student_decoder_layers=1,
alpha_hid=2.0,
teacher=T5_TINY,
model_name_or_path=T5_TINY,
tokenizer_name=T5_TINY,
)
self._test_distiller_cli(updates)
def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
alpha_encoder_loss=0.4,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
model = distill_main(argparse.Namespace(**args_d))
if not check_contents:
return model
contents = os.listdir(output_dir)
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
contents = {os.path.basename(p) for p in contents}
self.assertIn(ckpt_name, contents)
self.assertIn("test_generations.txt", contents)
self.assertIn("test_results.txt", contents)
metrics = load_json(model.metrics_save_path)
last_step_stats = metrics["val"][-1]
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["test"]), 1)
return model
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval_bart(model):
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
assert not output_file_name.exists()
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
_dump_articles(tmp, articles)
testargs = ["run_eval.py", str(tmp), str(output_file_name), model] # TODO: test score_path
with patch.object(sys, "argv", testargs):
run_generate()
assert Path(output_file_name).exists()
os.remove(Path(output_file_name))
@pytest.mark.parametrize(
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_name_or_path=model,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
output_dir=output_dir,
do_predict=True,
task=task,
)
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
main(args)
@pytest.mark.parametrize(
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)
def test_dataset(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = SummarizationDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:
assert batch["attention_mask"].shape == batch["input_ids"].shape
# show that articles were trimmed.
assert batch["input_ids"].shape[1] == max_len_source
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
# show that targets were truncated
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
assert max_len_target > trunc_target # Truncated

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export BS=32
export GAS=1
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--val_check_interval 0.25 \
--n_val 500 \
--num_train_epochs 2 \
--freeze_encoder --freeze_embeds --data_dir $CNN_DIR \
--max_target_length 142 --val_max_target_length=142 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--data_dir $CNN_DIR \
--model_name_or_path sshleifer/student_cnn_12_6 \
--tokenizer_name facebook/bart-large \
--output_dir distilbart-cnn-12-6 \
$@

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export BS=16
export GAS=2
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.1 --n_val 1000 \
--teacher facebook/bart-large-xsum --data_dir $XSUM_DIR \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--student_decoder_layers 6 --student_encoder_layers 12 \
--freeze_encoder --freeze_embeds \
--model_name_or_path IGNORED \
--alpha_hid=3. --length_penalty=0.5 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
--tokenizer_name facebook/bart-large \
--output_dir distilbart_xsum_12_6 \
$@

261
examples/seq2seq/utils.py Normal file
View File

@@ -0,0 +1,261 @@
import itertools
import json
import os
import pickle
from pathlib import Path
from typing import Callable, Dict, Iterable, List
import git
import numpy as np
import torch
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
def encode_file(
tokenizer,
data_path,
max_length,
pad_to_max_length=True,
return_tensors="pt",
overwrite_cache=False,
prefix="",
tok_name="",
):
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
if not overwrite_cache and cache_path.exists():
try:
examples = torch.load(cache_path)
assert isinstance(examples, list)
return examples
except Exception:
print(f"failed to load from {cache_path}, retokenizing {data_path}")
data_path = Path(data_path)
lns = lmap(str.strip, data_path.open().readlines())
lns = [prefix + text for text in lns]
assert lns, f"found empty file at {data_path}"
examples = []
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
tokenized = tokenizer.batch_encode_plus(
[text],
max_length=max_length,
pad_to_max_length=pad_to_max_length,
add_prefix_space=True,
truncation=True,
return_tensors=return_tensors,
)
assert tokenized.input_ids.shape[1] == max_length
examples.append(tokenized)
torch.save(lmap(dict, examples), cache_path.open("wb"))
return examples
def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
def calculate_bleu_score(output_lns, refs_lns) -> dict:
return {"bleu": corpus_bleu(output_lns, [refs_lns]).score}
def trim_batch(
input_ids, pad_token_id, attention_mask=None,
):
"""Remove columns that are populated exclusively by pad_token_id"""
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
if attention_mask is None:
return input_ids[:, keep_column_mask]
else:
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class SummarizationDataset(Dataset):
def __init__(
self,
tokenizer,
data_dir,
type_path="train",
max_source_length=1024,
max_target_length=56,
n_obs=None,
overwrite_cache=False,
prefix="",
):
super().__init__()
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
self.source = encode_file(
tokenizer,
os.path.join(data_dir, type_path + ".source"),
max_source_length,
overwrite_cache=overwrite_cache,
prefix=prefix,
tok_name=tok_name,
)
tgt_path = os.path.join(data_dir, type_path + ".target")
if hasattr(tokenizer, "set_lang"):
tokenizer.set_lang("ro_RO") # HACK: only applies to mbart
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
if n_obs is not None:
self.source = self.source[:n_obs]
self.target = self.target[:n_obs]
self.pad_token_id = tokenizer.pad_token_id
def __len__(self):
return len(self.source)
def __getitem__(self, index):
source_ids = self.source[index]["input_ids"].squeeze()
target_ids = self.target[index]["input_ids"].squeeze()
src_mask = self.source[index]["attention_mask"].squeeze()
return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}
@staticmethod
def trim_seq2seq_batch(batch, pad_token_id):
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
return source_ids, source_mask, y
def collate_fn(self, batch) -> dict:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
return batch
@property
def src_lens(self): # Can delete?
return lmap(len, self.source)
@property
def tgt_lens(self):
return lmap(len, self.target)
def make_sortish_sampler(self, batch_size):
return SortishSampler(self.source, batch_size)
class SortishSampler(Sampler):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
def __init__(self, data, batch_size):
self.data, self.bs = data, batch_size
def key(self, i):
return len(self.data[i])
def __len__(self) -> int:
return len(self.data)
def __iter__(self):
idxs = np.random.permutation(len(self.data))
sz = self.bs * 50
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
sz = self.bs
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
sort_idx = np.concatenate((ck_idx[0], sort_idx))
return iter(sort_idx)
def use_task_specific_params(model, task):
# update config with summarization specific params
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
model.config.update(task_specific_params.get(task, {}))
def pickle_load(path):
"""pickle.load(path)"""
with open(path, "rb") as f:
return pickle.load(f)
def pickle_save(obj, path):
"""pickle.dump(obj, path)"""
with open(path, "wb") as f:
return pickle.dump(obj, f)
def flatten_list(summary_ids: List[List]):
return [x for x in itertools.chain.from_iterable(summary_ids)]
def save_git_info(folder_path: str) -> None:
"""Save git information to output_dir/git_log.json"""
repo_infos = get_git_info()
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
def save_json(content, path):
with open(path, "w") as f:
json.dump(content, f, indent=4)
def load_json(path):
with open(path) as f:
return json.load(f)
def get_git_info():
repo = git.Repo(search_parent_directories=True)
repo_infos = {
"repo_id": str(repo),
"repo_sha": str(repo.head.object.hexsha),
"repo_branch": str(repo.active_branch),
}
return repo_infos
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
aggregator = scoring.BootstrapAggregator()
for reference_ln, output_ln in zip(reference_lns, output_lns):
scores = scorer.score(reference_ln, output_ln)
aggregator.add_scores(scores)
result = aggregator.aggregate()
return {k: v.mid.fmeasure for k, v in result.items()}
def freeze_params(model: nn.Module):
for par in model.parameters():
par.requires_grad = False
def grad_status(model: nn.Module) -> Iterable:
return (par.requires_grad for par in model.parameters())
def any_requires_grad(model: nn.Module) -> bool:
return any(grad_status(model))
def assert_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
n_require_grad = sum(lmap(int, model_grads))
npars = len(model_grads)
assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
def assert_not_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
npars = len(model_grads)
assert any(model_grads), f"none of {npars} weights require grad"