examples/seq2seq supports translation (#5202)
169
examples/seq2seq/README.md
Normal file
@@ -0,0 +1,169 @@
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
Summarization support is more mature than translation support.
Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
For `bertabs` instructions, see `bertabs/README.md`.

### Data

CNN/DailyMail data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz

export CNN_DIR=${PWD}/cnn_dm
```

This should create a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file format. Each article to be summarized is on its own line.

XSUM Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```


WMT16 English-Romanian Translation Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export ENRO_DIR=${PWD}/wmt_en_ro
```

If you are using your own data, it must be formatted as one directory with 6 files: `train.source`, `train.target`, `val.source`, `val.target`, `test.source`, `test.target`.
The `.source` files are the input, the `.target` files are the desired output.
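
Before launching a long run on your own data, it can help to check the directory layout described above. The following is a minimal sketch, not part of this commit: the helper name `check_data_dir` is hypothetical, and it assumes that line *i* of a `.source` file pairs with line *i* of the matching `.target` file.

```python
from pathlib import Path

# The split names and suffixes come from the six file names listed above.
SPLITS = ("train", "val", "test")
SUFFIXES = ("source", "target")


def check_data_dir(data_dir):
    """Return a list of problems found in `data_dir` (an empty list means it looks fine)."""
    data_dir = Path(data_dir)
    problems = []
    for split in SPLITS:
        counts = {}
        for suffix in SUFFIXES:
            path = data_dir / f"{split}.{suffix}"
            if not path.is_file():
                problems.append(f"missing file: {path.name}")
                continue
            with path.open(encoding="utf-8") as handle:
                counts[suffix] = sum(1 for _ in handle)
        # Assumption: each source line is paired with the target line at the same position.
        if len(counts) == 2 and counts["source"] != counts["target"]:
            problems.append(f"{split}: {counts['source']} source lines vs {counts['target']} target lines")
    return problems
```

Calling `check_data_dir("my_data")` on a hypothetical directory returns human-readable messages rather than raising, so all problems can be fixed in one pass.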

### Evaluation

To create summaries for each article in the dataset, run:
```bash
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
```
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
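
Batch size matters because the source file is processed a few lines at a time. Here is a minimal sketch of that batching step, assuming the evaluation script groups consecutive lines of the source file. The helper name `chunks` and the sample data are illustrative and are not taken from `run_eval.py`.

```python
def chunks(items, batch_size):
    """Yield successive slices of `items` holding at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


# Ten made-up articles split into batches of four: the last batch is smaller.
articles = [f"article {i}" for i in range(10)]
batches = list(chunks(articles, batch_size=4))
print([len(batch) for batch in batches])
```

A smaller batch size lowers peak memory because fewer articles are held on the GPU at once, at the cost of more forward passes.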


### Summarization Finetuning
Run/modify `finetune.sh`

The following command should work on a 16GB GPU:
```bash
./finetune.sh \
    --data_dir $XSUM_DIR \
    --train_batch_size=1 \
    --eval_batch_size=1 \
    --output_dir=xsum_results \
    --num_train_epochs 1 \
    --model_name_or_path facebook/bart-large
```

*Note*: The following tips mostly apply to summarization finetuning.

Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- Try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size (3hr/epoch with bs=8, see the "xsum_shared_task" command below).
- `fp16_opt_level=O1` (the default) works best.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
  (It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
  Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
- This warning can be safely ignored:
    > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
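
The tip about truncated references can be illustrated with a toy unigram-overlap score. This is a simplified stand-in for ROUGE-1 F1 (whitespace tokens, no stemming), not the `calculate_rouge` function mentioned above, and the sentences are made up. It only shows how a shortened reference can raise the score of the same generated text.

```python
from collections import Counter


def unigram_f1(candidate, reference):
    """F1 over whitespace-token overlap; a rough proxy for ROUGE-1, not the real metric."""
    cand, ref = Counter(candidate.split()), Counter(reference.split())
    overlap = sum((cand & ref).values())  # clipped count of shared tokens
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


generated = "the cat sat"
full_reference = "the cat sat on the mat today"
# Mimic a maximum target length of 3 tokens applied to the label.
truncated_reference = " ".join(full_reference.split()[:3])

print(unigram_f1(generated, full_reference))
print(unigram_f1(generated, truncated_reference))
```

Recall is computed against the reference length, so cutting the reference short removes tokens the model would otherwise be penalized for missing.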

#### Finetuning Outputs
As you train, `output_dir` will be filled with files that look something like this (comments are mine).
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:

```bash
output_dir
├── best_tfmr  # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
│   ├── config.json
│   ├── merges.txt
│   ├── pytorch_model.bin
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── git_log.json  # repo, branch, and commit hash
├── val_avg_rouge2=0.1984-step_count=11.ckpt  # this is a pytorch lightning checkpoint associated with the best val score.
├── metrics.json  # new validation metrics will continually be appended to this
├── student  # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
│   ├── config.json
│   └── pytorch_model.bin
├── test_generations.txt
# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
├── test_results.txt  # a convenience file with the test set metrics. This data is also in metrics.json['test']
├── hparams.pkl  # the command line args passed after some light preprocessing. Should be saved fairly quickly.
```
After training, you can recover the best checkpoint by running
```python
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
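
The Lightning checkpoint name in the listing encodes the validation score and step count. If several `.ckpt` files end up in `output_dir`, a small parser can pick the best one. This is a sketch only: the pattern is inferred from the single example name `val_avg_rouge2=0.1984-step_count=11.ckpt`, the helper name `best_checkpoint` is hypothetical, and it assumes a higher score is better.

```python
import re

# Pattern inferred from the one example file name in the listing; other runs may differ.
CKPT_PATTERN = re.compile(
    r"^val_avg_(?P<metric>\w+?)=(?P<score>\d+\.\d+)-step_count=(?P<step>\d+)\.ckpt$"
)


def best_checkpoint(file_names):
    """Return (file_name, score) for the highest-scoring matching name, or None."""
    scored = []
    for name in file_names:
        match = CKPT_PATTERN.match(name)
        if match:
            scored.append((name, float(match.group("score"))))
    if not scored:
        return None
    return max(scored, key=lambda item: item[1])
```

For example, `best_checkpoint(os.listdir(output_dir))` would ignore `metrics.json`, `hparams.pkl`, and the other non-checkpoint files.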


### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.

Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
```bash
./finetune.sh \
    --data_dir $XSUM_DIR \
    --output_dir xsum_frozen_embs \
    --model_name_or_path facebook/bart-large \
    --logger wandb_shared \
    --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
    --num_train_epochs 6 \
    --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
```

Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)


### Distilbart

#### No Teacher Distillation
To run the simpler distilbart-cnn style distillation, all you need is data, a GPU, and a properly initialized student.
You don't even need `distillation.py`.

Some [un-finetuned students](https://huggingface.co/models?search=sshleifer%2Fstudent) are available for replication purposes.
They are initialized by copying layers from the associated `bart-large-{cnn|xsum}` teacher using `--init_strategy alternate`. (You can read about that in `initialization_utils.py`.)
The command that produced `sshleifer/distilbart-cnn-12-6` is:
```bash
./train_distilbart_cnn.sh
```
Runtime: 6H on an NVIDIA RTX 24GB GPU.

*Note*: You can get the same simple distillation logic by using `./run_distiller.sh --no_teacher` followed by the same arguments as in `train_distilbart_cnn.sh`.
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
because you will have the same hyperparameters logged in every run.
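
To make the idea of initializing a student by copying teacher layers concrete, here is a minimal sketch of choosing which teacher layers to copy. The actual mapping used by `--init_strategy alternate` lives in `initialization_utils.py`, which is not shown here. The evenly spaced selection below, the function name `pick_teacher_layers`, and the 12-to-6 example are all assumptions made for illustration.

```python
def pick_teacher_layers(n_teacher, n_student):
    """Return `n_student` evenly spaced teacher layer indices, keeping the first and last layer."""
    if not 1 <= n_student <= n_teacher:
        raise ValueError("need 1 <= n_student <= n_teacher")
    if n_student == 1:
        return [0]
    step = (n_teacher - 1) / (n_student - 1)
    return [round(i * step) for i in range(n_student)]


# Hypothetical example: a 6-layer student initialized from a 12-layer teacher.
print(pick_teacher_layers(12, 6))
```

Keeping both the first and the last teacher layer is a design choice of this sketch, based on the intuition that the layers closest to the input and to the output are the hardest to relearn.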

#### With a teacher
*Note*: only BART variants are supported.

In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
This is how `sshleifer/distilbart-xsum*` checkpoints were produced.

The command that produced `sshleifer/distilbart-xsum-12-6` is:

```bash
./train_distilbart_xsum.sh
```

Runtime: 13H on a V-100 16GB GPU.
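
As a rough illustration of what "similar logits and hidden states" can mean, the sketch below combines a temperature-softened KL term on the logits with a mean-squared-error term on the hidden states. This is generic knowledge distillation written in plain Python, not the loss implemented by `BartSummarizationDistiller`. The temperature, the two weights, and all function names are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of `logits / temperature`."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions; assumes q is strictly positive."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mse(a, b):
    """Mean squared error between two equally long vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distillation_loss(student_logits, teacher_logits, student_hidden, teacher_hidden,
                      temperature=2.0, alpha_logits=0.8, alpha_hidden=0.2):
    """Weighted sum of a logit-matching term and a hidden-state-matching term."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    logit_term = kl_divergence(teacher_probs, student_probs)
    hidden_term = mse(student_hidden, teacher_hidden)
    return alpha_logits * logit_term + alpha_hidden * hidden_term
```

The loss is zero when the student reproduces the teacher exactly and grows as the two diverge. In practice a term like this is added to the usual cross-entropy loss on the reference summaries.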

### Contributing
- Follow the standard contributing guidelines and code of conduct.
- Add tests to `test_seq2seq_examples.py`.
- To run only the seq2seq tests, you must be in the root of the repository and run:
```bash
pytest examples/seq2seq/
```
0
examples/seq2seq/__init__.py
Normal file
61
examples/seq2seq/bertabs/README.md
Normal file
@@ -0,0 +1,61 @@
# Text Summarization with Pretrained Encoders

This folder contains part of the code necessary to reproduce the results on abstractive summarization from the article [Text Summarization with Pretrained Encoders](https://arxiv.org/pdf/1908.08345.pdf) by [Yang Liu](https://nlp-yang.github.io/) and [Mirella Lapata](https://homepages.inf.ed.ac.uk/mlap/). It can also be used to summarize any document.

The original code can be found in Yang Liu's [GitHub repository](https://github.com/nlpyang/PreSumm).

The model is loaded with the pre-trained weights for the abstractive summarization model trained on the CNN/Daily Mail dataset with an extractive and then an abstractive task.

## Setup

```
git clone https://github.com/huggingface/transformers && cd transformers
pip install .
pip install nltk py-rouge
cd examples/seq2seq/bertabs
```

## Reproduce the authors' ROUGE score

To reproduce the authors' results on the CNN/Daily Mail dataset, first download both the CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") into the same folder. Then uncompress the archives by running:

```bash
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
```

Then move all the stories to the same folder. We will refer to the path where you uncompressed both archives as `$DATA_PATH`. Run the following in the same folder as `run_summarization.py` (`--summaries_output_dir` is optional):

```bash
python run_summarization.py \
    --documents_dir $DATA_PATH \
    --summaries_output_dir $SUMMARIES_PATH \
    --no_cuda false \
    --batch_size 4 \
    --min_length 50 \
    --max_length 200 \
    --beam_size 5 \
    --alpha 0.95 \
    --block_trigram true \
    --compute_rouge true
```

The script executes on GPU if one is available and if `no_cuda` is not set to `true`. Inference on multiple GPUs is not supported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written to a `rouge_scores.txt` file. The script takes 30 hours to compute with a single Tesla V100 GPU and a batch size of 10 (300,000 texts to summarize).
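
The boolean options in the command above reach the script as text. In Python, `bool("false")` is `True`, so an argparse option declared with `type=bool` would read `--no_cuda false` as true. A small converter avoids that. The sketch below uses illustrative names (`str2bool`, `--naive_flag`), and it is worth checking how the script you run declares these flags.

```python
import argparse


def str2bool(value):
    """Map common textual spellings to a bool; reject anything else with an argparse error."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--no_cuda", default=False, type=str2bool)
# Hypothetical flag declared with plain `bool`, only to show the contrast.
parser.add_argument("--naive_flag", default=False, type=bool)

args = parser.parse_args(["--no_cuda", "false", "--naive_flag", "false"])
print(args.no_cuda, args.naive_flag)
```

With the converter, `false` parses to `False`. The plain `bool` flag becomes `True` because any non-empty string is truthy.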

## Summarize any text

Put the documents that you would like to summarize in a folder (the path to which is referred to as `$DATA_PATH` below) and run the following in the same folder as `run_summarization.py` (`--summaries_output_dir` is optional):

```bash
python run_summarization.py \
    --documents_dir $DATA_PATH \
    --summaries_output_dir $SUMMARIES_PATH \
    --no_cuda false \
    --batch_size 4 \
    --min_length 50 \
    --max_length 200 \
    --beam_size 5 \
    --alpha 0.95 \
    --block_trigram true
```

You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset, you will need to tweak the stories/summaries import in `utils_summarization.py` and tell it where to fetch the reference summaries.
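
To build intuition for `alpha`, the sketch below assumes the beam search uses the GNMT-style length penalty that is common in OpenNMT-derived code, where a hypothesis's log-probability is divided by `((5 + length) / 6) ** alpha`. That is an assumption: the actual implementation lives in `modeling_bertabs.py`, whose diff is not shown here. The log-probabilities and function names are made up.

```python
def length_penalty(length, alpha):
    """GNMT-style penalty: 1.0 for a one-token hypothesis, growing with length when alpha > 0."""
    return ((5.0 + length) / 6.0) ** alpha


def normalized_score(log_prob, length, alpha):
    """Log-probability divided by the length penalty; higher (closer to zero) is better."""
    return log_prob / length_penalty(length, alpha)


# Made-up hypotheses: a short one with a better raw score, a long one with a worse raw score.
for alpha in (0.0, 0.95):
    short = normalized_score(-6.0, length=50, alpha=alpha)
    long = normalized_score(-8.0, length=200, alpha=alpha)
    print(alpha, short, long)
```

With `alpha=0` the penalty is 1 and the raw log-probabilities are compared directly, which favors short outputs. Larger values of `alpha` increasingly favor longer summaries.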
0
examples/seq2seq/bertabs/__init__.py
Normal file
97
examples/seq2seq/bertabs/configuration_bertabs.py
Normal file
@@ -0,0 +1,97 @@
# coding=utf-8
# Copyright 2019 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BertAbs configuration """
import logging

from transformers import PretrainedConfig


logger = logging.getLogger(__name__)


BERTABS_FINETUNED_CONFIG_MAP = {
    "bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/config.json",
}


class BertAbsConfig(PretrainedConfig):
    r""" Class to store the configuration of the BertAbs model.

    Arguments:
        vocab_size: int
            Number of tokens in the vocabulary.
        max_pos: int
            The maximum sequence length that this model will be used with.
        enc_layers: int
            The number of hidden layers in the Transformer encoder.
        enc_hidden_size: int
            The size of the encoder's layers.
        enc_heads: int
            The number of attention heads for each attention layer in the encoder.
        enc_ff_size: int
            The size of the encoder's feed-forward layers.
        enc_dropout: float
            The dropout probability for all fully connected layers in the
            embeddings, layers, pooler and also the attention probabilities in
            the encoder.
        dec_layers: int
            The number of hidden layers in the decoder.
        dec_hidden_size: int
            The size of the decoder's layers.
        dec_heads: int
            The number of attention heads for each attention layer in the decoder.
        dec_ff_size: int
            The size of the decoder's feed-forward layers.
        dec_dropout: float
            The dropout probability for all fully connected layers in the
            embeddings, layers, pooler and also the attention probabilities in
            the decoder.
    """

    model_type = "bertabs"

    def __init__(
        self,
        vocab_size=30522,
        max_pos=512,
        enc_layers=6,
        enc_hidden_size=512,
        enc_heads=8,
        enc_ff_size=512,
        enc_dropout=0.2,
        dec_layers=6,
        dec_hidden_size=768,
        dec_heads=8,
        dec_ff_size=2048,
        dec_dropout=0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.max_pos = max_pos

        self.enc_layers = enc_layers
        self.enc_hidden_size = enc_hidden_size
        self.enc_heads = enc_heads
        self.enc_ff_size = enc_ff_size
        self.enc_dropout = enc_dropout

        self.dec_layers = dec_layers
        self.dec_hidden_size = dec_hidden_size
        self.dec_heads = dec_heads
        self.dec_ff_size = dec_ff_size
        self.dec_dropout = dec_dropout
@@ -0,0 +1,176 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints.

The script looks like it is doing something trivial but it is not. The "weights"
proposed by the authors are actually the entire model pickled. We need to load
the model within the original codebase to be able to only save its `state_dict`.
"""

import argparse
import logging
from collections import namedtuple

import torch

from model_bertabs import BertAbsSummarizer
from models.model_builder import AbsSummarizer  # The authors' implementation
from transformers import BertTokenizer


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


SAMPLE_TEXT = "Hello world! cécé herlolip"


BertAbsConfig = namedtuple(
    "BertAbsConfig",
    [
        "temp_dir",
        "large",
        "use_bert_emb",
        "finetune_bert",
        "encoder",
        "share_emb",
        "max_pos",
        "enc_layers",
        "enc_hidden_size",
        "enc_heads",
        "enc_ff_size",
        "enc_dropout",
        "dec_layers",
        "dec_hidden_size",
        "dec_heads",
        "dec_ff_size",
        "dec_dropout",
    ],
)


def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
    """ Copy/paste and tweak the pre-trained weights provided by the creators
    of BertAbs for the internal architecture.
    """

    # Instantiate the authors' model with the pre-trained weights
    config = BertAbsConfig(
        temp_dir=".",
        finetune_bert=False,
        large=False,
        share_emb=True,
        use_bert_emb=False,
        encoder="bert",
        max_pos=512,
        enc_layers=6,
        enc_hidden_size=512,
        enc_heads=8,
        enc_ff_size=512,
        enc_dropout=0.2,
        dec_layers=6,
        dec_hidden_size=768,
        dec_heads=8,
        dec_ff_size=2048,
        dec_dropout=0.2,
    )
    checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
    original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
    original.eval()

    new_model = BertAbsSummarizer(config, torch.device("cpu"))
    new_model.eval()

    # -------------------
    # Convert the weights
    # -------------------

    logging.info("convert the model")
    new_model.bert.load_state_dict(original.bert.state_dict())
    new_model.decoder.load_state_dict(original.decoder.state_dict())
    new_model.generator.load_state_dict(original.generator.state_dict())

    # -----------------------------------
    # Make sure the outputs are identical
    # -----------------------------------

    logging.info("Make sure that the models' outputs are identical")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # prepare the model inputs
    encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
    encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
    encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
    decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
    decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
    decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)

    # failsafe to make sure the weights reset does not affect the
    # loaded weights.
    assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0

    # forward pass
    src = encoder_input_ids
    tgt = decoder_input_ids
    segs = token_type_ids = None
    clss = None
    mask_src = encoder_attention_mask = None
    mask_tgt = decoder_attention_mask = None
    mask_cls = None

    # The original model does not apply the generator layer immediately but rather in
    # the beam search (where it combines softmax + linear layer). Since we already
    # apply the softmax in our generation process we only apply the linear layer here.
    # We make sure that the outputs of the full stack are identical
    output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
    output_original_generator = original.generator(output_original_model)

    output_converted_model = new_model(
        encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
    )[0]
    output_converted_generator = new_model.generator(output_converted_model)

    maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
    print("Maximum absolute difference between weights: {:.2f}".format(maximum_absolute_difference))
    maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
    print("Maximum absolute difference between weights: {:.2f}".format(maximum_absolute_difference))

    are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
    if are_identical:
        logging.info("all weights are equal up to 1e-3")
    else:
        raise ValueError("the weights are different. The new model is likely different from the original one.")

    # The model has been saved with torch.save(model) and this is bound to the exact
    # directory structure. We save the state_dict instead.
    logging.info("saving the model's state dictionary")
    torch.save(
        new_model.state_dict(), "./bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bertabs_checkpoint_path", default=None, type=str, required=True, help="Path to the official PyTorch dump.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model.",
    )
    args = parser.parse_args()

    convert_bertabs_checkpoints(
        args.bertabs_checkpoint_path, args.pytorch_dump_folder_path,
    )
1026
examples/seq2seq/bertabs/modeling_bertabs.py
Normal file
File diff suppressed because it is too large
5
examples/seq2seq/bertabs/requirements.txt
Normal file
@@ -0,0 +1,5 @@
transformers

# For ROUGE
nltk
py-rouge
324
examples/seq2seq/bertabs/run_summarization.py
Normal file
@@ -0,0 +1,324 @@
#! /usr/bin/python3
import argparse
import logging
import os
import sys
from collections import namedtuple

import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

from modeling_bertabs import BertAbs, build_predictor
from transformers import BertTokenizer

# Absolute import, like `modeling_bertabs` above: the README runs this file directly
# with `python run_summarization.py`, where a relative import has no parent package.
from utils_summarization import (
    CNNDMDataset,
    build_mask,
    compute_token_type_ids,
    encode_for_summarization,
    truncate_or_pad,
)


logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)


Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])


def evaluate(args):
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    model = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
    model.to(args.device)
    model.eval()

    symbols = {
        "BOS": tokenizer.vocab["[unused0]"],
        "EOS": tokenizer.vocab["[unused1]"],
        "PAD": tokenizer.vocab["[PAD]"],
    }

    if args.compute_rouge:
        reference_summaries = []
        generated_summaries = []

        import rouge
        import nltk

        nltk.download("punkt")
        rouge_evaluator = rouge.Rouge(
            metrics=["rouge-n", "rouge-l"],
            max_n=2,
            limit_length=True,
            length_limit=args.beam_size,
            length_limit_type="words",
            apply_avg=True,
            apply_best=False,
            alpha=0.5,  # Default F1_score
            weight_factor=1.2,
            stemming=True,
        )

    # these (unused) arguments are defined to keep the compatibility
    # with the legacy code and will be deleted in a next iteration.
    args.result_path = ""
    args.temp_dir = ""

    data_iterator = build_data_iterator(args, tokenizer)
    predictor = build_predictor(args, tokenizer, symbols, model)

    logger.info("***** Running evaluation *****")
    logger.info(" Number examples = %d", len(data_iterator.dataset))
    logger.info(" Batch size = %d", args.batch_size)
    logger.info("")
    logger.info("***** Beam Search parameters *****")
    logger.info(" Beam size = %d", args.beam_size)
    logger.info(" Minimum length = %d", args.min_length)
    logger.info(" Maximum length = %d", args.max_length)
    logger.info(" Alpha (length penalty) = %.2f", args.alpha)
    logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))

    for batch in tqdm(data_iterator):
        batch_data = predictor.translate_batch(batch)
        translations = predictor.from_batch(batch_data)
        summaries = [format_summary(t) for t in translations]
        save_summaries(summaries, args.summaries_output_dir, batch.document_names)

        if args.compute_rouge:
            reference_summaries += batch.tgt_str
            generated_summaries += summaries

    if args.compute_rouge:
        scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
        str_scores = format_rouge_scores(scores)
        save_rouge_scores(str_scores)
        print(str_scores)


def save_summaries(summaries, path, original_document_name):
    """ Write the summaries in files that are named after the original
    files, with `_summary` appended to the name.

    Attributes:
        original_document_name: List[string]
            Names of the documents that were summarized.
        path: string
            Path where the summaries will be written
        summaries: List[string]
            The summaries that we produced.
    """
    for summary, document_name in zip(summaries, original_document_name):
        # Prepare the summary file's name
        if "." in document_name:
            bare_document_name = ".".join(document_name.split(".")[:-1])
            extension = document_name.split(".")[-1]
            name = bare_document_name + "_summary." + extension
        else:
            name = document_name + "_summary"

        file_path = os.path.join(path, name)
        with open(file_path, "w") as output:
            output.write(summary)


def format_summary(translation):
    """ Transforms the output of the `from_batch` function
    into nicely formatted summaries.
    """
    raw_summary, _, _ = translation
    summary = (
        raw_summary.replace("[unused0]", "")
        .replace("[unused3]", "")
        .replace("[PAD]", "")
        .replace("[unused1]", "")
    )
    # `str.replace` treats " +" as literal text, not as a regular expression, so
    # collapse runs of whitespace explicitly before handling the sentence separators.
    summary = " ".join(summary.split())
    summary = summary.replace(" [unused2] ", ". ").replace("[unused2]", "").strip()

    return summary


def format_rouge_scores(scores):
    return """\n
****** ROUGE SCORES ******

** ROUGE 1
F1 >> {:.3f}
Precision >> {:.3f}
Recall >> {:.3f}

** ROUGE 2
F1 >> {:.3f}
Precision >> {:.3f}
Recall >> {:.3f}

** ROUGE L
F1 >> {:.3f}
Precision >> {:.3f}
Recall >> {:.3f}""".format(
        scores["rouge-1"]["f"],
        scores["rouge-1"]["p"],
        scores["rouge-1"]["r"],
        scores["rouge-2"]["f"],
        scores["rouge-2"]["p"],
        scores["rouge-2"]["r"],
        scores["rouge-l"]["f"],
        scores["rouge-l"]["p"],
        scores["rouge-l"]["r"],
    )


def save_rouge_scores(str_scores):
    with open("rouge_scores.txt", "w") as output:
        output.write(str_scores)


#
# LOAD the dataset
#


def build_data_iterator(args, tokenizer):
    dataset = load_and_cache_examples(args, tokenizer)
    sampler = SequentialSampler(dataset)

    def collate_fn(data):
        return collate(data, tokenizer, block_size=512, device=args.device)

    iterator = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,)

    return iterator


def load_and_cache_examples(args, tokenizer):
    dataset = CNNDMDataset(args.documents_dir)
    return dataset


def collate(data, tokenizer, block_size, device):
    """ Collate formats the data passed to the data loader.

    In particular we tokenize the data batch after batch to avoid keeping them
    all in memory. We output the data as a namedtuple to fit the original BertAbs's
    API.
    """
    data = [x for x in data if not len(x[1]) == 0]  # remove empty_files
    names = [name for name, _, _ in data]
    summaries = [" ".join(summary_list) for _, _, summary_list in data]

    encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
    encoded_stories = torch.tensor(
        [truncate_or_pad(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
    )
    encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
    encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)

    batch = Batch(
        document_names=names,
        batch_size=len(encoded_stories),
        src=encoded_stories.to(device),
        segs=encoder_token_type_ids.to(device),
        mask_src=encoder_mask.to(device),
        tgt_str=summaries,
    )

    return batch


def decode_summary(summary_tokens, tokenizer):
    """ Decode the summary and return it in a format
    suitable for evaluation.
    """
    summary_tokens = summary_tokens.to("cpu").numpy()
    summary = tokenizer.decode(summary_tokens)
    sentences = summary.split(".")
    sentences = [s + "." for s in sentences]
    return sentences


def main():
|
||||
""" The main function defines the interface with the users.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--documents_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The folder where the documents to summarize are located.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summaries_output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="The folder in which the summaries should be written. Defaults to the folder where the documents are.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compute_rouge",
|
||||
default=False,
|
||||
type=bool,
|
||||
required=False,
|
||||
help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.",
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
# BEAM SEARCH arguments
|
||||
parser.add_argument(
|
||||
"--min_length", default=50, type=int, help="Minimum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length", default=200, type=int, help="Maximum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam_size", default=5, type=int, help="The number of beams to start with for each example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha", default=0.95, type=float, help="The value of alpha for the length penalty in the beam search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_trigram",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Select device (distributed not available)
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
|
||||
# Check the existence of directories
|
||||
if not args.summaries_output_dir:
|
||||
args.summaries_output_dir = args.documents_dir
|
||||
|
||||
if not documents_dir_is_valid(args.documents_dir):
|
||||
raise FileNotFoundError(
|
||||
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
|
||||
)
|
||||
os.makedirs(args.summaries_output_dir, exist_ok=True)
|
||||
|
||||
evaluate(args)
|
||||
|
||||
|
||||
def documents_dir_is_valid(path):
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
|
||||
file_list = os.listdir(path)
|
||||
if len(file_list) == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
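Several options in `main()` are declared with `type=bool`. With `argparse`, that passes the raw command-line string to `bool()`, so any non-empty value, including `"False"`, parses as `True`. The sketch below reproduces the behavior with the same `--no_cuda` declaration. The `--force_cpu` flag is a hypothetical name used here only to contrast it with `action="store_true"`; it is not an option of the script.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # Same pattern as the script: the argument string is converted with bool().
    parser.add_argument("--no_cuda", default=False, type=bool)
    # Hypothetical alternative: a presence-only flag, False unless given.
    parser.add_argument("--force_cpu", action="store_true")
    return parser


parser = build_parser()
print(parser.parse_args([]).no_cuda)                      # False (the default)
print(parser.parse_args(["--no_cuda", "False"]).no_cuda)  # True, since bool("False") is True
print(parser.parse_args(["--force_cpu"]).force_cpu)       # True
```

Switching an existing option to `store_true` changes how it is invoked (`--compute_rouge true` would become `--compute_rouge`), so it is a trade-off rather than a drop-in fix.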
|
||||
100
examples/seq2seq/bertabs/test_utils_summarization.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .utils_summarization import build_mask, compute_token_type_ids, process_story, truncate_or_pad
|
||||
|
||||
|
||||
class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.block_size = 10
|
||||
|
||||
def test_fit_to_block_sequence_too_small(self):
|
||||
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
sequence = [1, 2, 3, 4]
|
||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_fit_exactly(self):
|
||||
""" Do nothing if the sequence is the right size. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_too_big(self):
|
||||
""" Truncate the sequence if it is too long. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
""" Processing a story with no highlights returns an empty list for the summary.
|
||||
"""
|
||||
raw_story = """It was the year of Our Lord one thousand seven hundred and
|
||||
seventy-five.\n\nSpiritual revelations were conceded to England at that
|
||||
favoured period, as at this."""
|
||||
_, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_empty_story(self):
|
||||
""" An empty story returns an empty collection of lines.
|
||||
"""
|
||||
raw_story = ""
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(story_lines, [])
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_story_with_missing_period(self):
|
||||
raw_story = (
|
||||
"It was the year of Our Lord one thousand seven hundred and "
|
||||
"seventy-five\n\nSpiritual revelations were conceded to England "
|
||||
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
|
||||
)
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
|
||||
expected_story_lines = [
|
||||
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
|
||||
"Spiritual revelations were conceded to England at that favoured period, as at this.",
|
||||
]
|
||||
self.assertEqual(expected_story_lines, story_lines)
|
||||
|
||||
expected_summary_lines = ["It was the best of times."]
|
||||
self.assertEqual(expected_summary_lines, summary_lines)
|
||||
|
||||
def test_build_mask_no_padding(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4])
|
||||
expected = torch.tensor([1, 1, 1, 1])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 23).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask_with_padding_equal_to_one(self):
|
||||
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
|
||||
|
||||
def test_compute_token_type_ids(self):
|
||||
separator = 101
|
||||
batch = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]])
|
||||
expected = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]])
|
||||
|
||||
result = compute_token_type_ids(batch, separator)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
167
examples/seq2seq/bertabs/utils_summarization.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# ------------
|
||||
# Data loading
|
||||
# ------------
|
||||
|
||||
|
||||
class CNNDMDataset(Dataset):
|
||||
""" Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
The class will process the documents that are located in the specified
|
||||
folder. The preprocessing will work on any document that is reasonably
|
||||
formatted. On the CNN/DailyMail dataset it will extract both the story
|
||||
and the summary.
|
||||
|
||||
CNN/DailyMail:
|
||||
|
||||
The CNN/DailyMail raw datasets are downloaded from [1]. The stories are
|
||||
stored in different files; the summary appears at the end of the story as
|
||||
sentences that are prefixed by the special `@highlight` line. To process
|
||||
the data, untar both datasets in the same folder, and pass the path to this
|
||||
folder as the `path` argument. The formatting code was inspired by [2].
|
||||
|
||||
[1] https://cs.nyu.edu/~kcho/
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
|
||||
def __init__(self, path="", prefix="train"):
|
||||
""" We initialize the class by listing all the documents to summarize.
|
||||
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
|
||||
"""
|
||||
assert os.path.isdir(path)
|
||||
|
||||
self.documents = []
|
||||
story_filenames_list = os.listdir(path)
|
||||
for story_filename in story_filenames_list:
|
||||
if "summary" in story_filename:
|
||||
continue
|
||||
path_to_story = os.path.join(path, story_filename)
|
||||
if not os.path.isfile(path_to_story):
|
||||
continue
|
||||
self.documents.append(path_to_story)
|
||||
|
||||
def __len__(self):
|
||||
""" Returns the number of documents. """
|
||||
return len(self.documents)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
document_path = self.documents[idx]
|
||||
document_name = document_path.split("/")[-1]
|
||||
with open(document_path, encoding="utf-8") as source:
|
||||
raw_story = source.read()
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
return document_name, story_lines, summary_lines
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
""" Extract the story and summary from a story file.
|
||||
|
||||
Arguments:
|
||||
raw_story (str): content of the story file as a UTF-8 encoded string.
|
||||
|
||||
Returns:
|
||||
A `(story_lines, summary_lines)` tuple of lists; `summary_lines` is empty if the story is empty or contains no highlights.
|
||||
"""
|
||||
nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
|
||||
|
||||
# some lines are missing a final period for an unknown reason; add it
|
||||
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
||||
|
||||
# gather article lines
|
||||
story_lines = []
|
||||
lines = deque(nonempty_lines)
|
||||
while True:
|
||||
try:
|
||||
element = lines.popleft()
|
||||
if element.startswith("@highlight"):
|
||||
break
|
||||
story_lines.append(element)
|
||||
except IndexError:
|
||||
# If "@highlight" is absent from the file, we pop elements until the
|
||||
# deque is empty, at which point popleft() raises an IndexError.
|
||||
return story_lines, []
|
||||
|
||||
# gather summary lines
|
||||
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
|
||||
|
||||
return story_lines, summary_lines
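The story/summary split performed by `process_story` can be illustrated with a simplified, self-contained sketch. It assumes the same file layout (article lines, then one or more `@highlight` markers, each followed by a summary sentence) and deliberately leaves out the missing-period fix. `split_story` is a hypothetical helper written for this illustration, not part of the module.

```python
def split_story(raw_story):
    """Split a raw story into article lines and highlight (summary) lines."""
    lines = [line.strip() for line in raw_story.split("\n") if line.strip()]
    story_lines, summary_lines = [], []
    seen_highlight = False
    for line in lines:
        if line.startswith("@highlight"):
            # Marker lines are never kept; everything after the first one is summary.
            seen_highlight = True
            continue
        if seen_highlight:
            summary_lines.append(line)
        else:
            story_lines.append(line)
    return story_lines, summary_lines


raw = "First sentence.\n\nSecond sentence.\n@highlight\n\nA highlight.\n@highlight\nAnother."
print(split_story(raw))
# (['First sentence.', 'Second sentence.'], ['A highlight.', 'Another.'])
print(split_story(""))  # ([], [])
```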
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u201d", ")"]
|
||||
if line.startswith("@highlight"):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + "."
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Encoding and preprocessing
|
||||
# --------------------------
|
||||
|
||||
|
||||
def truncate_or_pad(sequence, block_size, pad_token_id):
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter, we append padding tokens to the right of the sequence; if it is longer, we truncate it.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
sequence.extend([pad_token_id] * (block_size - len(sequence)))
|
||||
return sequence
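One detail of `truncate_or_pad` worth knowing when reusing it: the padding branch calls `list.extend`, so a short input list is modified in place and returned, while the truncation branch returns a new slice. The sketch below repeats the function so that it runs on its own. `truncate_or_pad_copy` is a hypothetical non-mutating variant added only for comparison.

```python
def truncate_or_pad(sequence, block_size, pad_token_id):
    """Same logic as above, repeated so the example is self-contained."""
    if len(sequence) > block_size:
        return sequence[:block_size]
    sequence.extend([pad_token_id] * (block_size - len(sequence)))
    return sequence


def truncate_or_pad_copy(sequence, block_size, pad_token_id):
    """Hypothetical variant that never modifies its input."""
    trimmed = list(sequence[:block_size])
    return trimmed + [pad_token_id] * (block_size - len(trimmed))


tokens = [1, 2, 3, 4]
padded = truncate_or_pad(tokens, 6, 0)
print(padded is tokens, tokens)  # True [1, 2, 3, 4, 0, 0]

tokens = [1, 2, 3, 4]
padded = truncate_or_pad_copy(tokens, 6, 0)
print(padded, tokens)  # [1, 2, 3, 4, 0, 0] [1, 2, 3, 4]
```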
|
||||
|
||||
|
||||
def build_mask(sequence, pad_token_id):
|
||||
""" Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1. """
|
||||
mask = torch.ones_like(sequence)
|
||||
idx_pad_tokens = sequence == pad_token_id
|
||||
mask[idx_pad_tokens] = 0
|
||||
return mask
|
||||
|
||||
|
||||
def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
""" Encode the story and summary lines, and join them
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
||||
story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
|
||||
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
||||
summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]
|
||||
|
||||
return story_token_ids, summary_token_ids
|
||||
|
||||
|
||||
def compute_token_type_ids(batch, separator_token_id):
|
||||
""" Segment embeddings as described in [1]
|
||||
|
||||
The values {0,1} were found in the repository [2].
|
||||
|
||||
Arguments:
|
||||
batch: torch.Tensor, size [batch_size, block_size]
|
||||
Batch of input.
|
||||
separator_token_id: int
|
||||
The value of the token that separates the segments.
|
||||
|
||||
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
|
||||
arXiv preprint arXiv:1908.08345 (2019).
|
||||
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
|
||||
"""
|
||||
batch_embeddings = []
|
||||
for sequence in batch:
|
||||
sentence_num = -1
|
||||
embeddings = []
|
||||
for s in sequence:
|
||||
if s == separator_token_id:
|
||||
sentence_num += 1
|
||||
embeddings.append(sentence_num % 2)
|
||||
batch_embeddings.append(embeddings)
|
||||
return torch.tensor(batch_embeddings)
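The alternating segment ids are easier to follow on plain lists. The sketch below applies the same counter logic to a single sequence without `torch`. `token_type_ids_for_sequence` is a hypothetical helper name, and the inputs are the rows used in `test_compute_token_type_ids`. Because the counter starts at -1 and Python evaluates `-1 % 2` as 1, tokens before the first separator get segment id 1.

```python
def token_type_ids_for_sequence(sequence, separator_token_id):
    """Alternate between segment ids each time the separator token is seen."""
    sentence_num = -1
    ids = []
    for token in sequence:
        if token == separator_token_id:
            sentence_num += 1
        ids.append(sentence_num % 2)
    return ids


print(token_type_ids_for_sequence([1, 2, 3, 4, 5, 6], 101))      # [1, 1, 1, 1, 1, 1]
print(token_type_ids_for_sequence([1, 2, 3, 101, 5, 6], 101))    # [1, 1, 1, 0, 0, 0]
print(token_type_ids_for_sequence([1, 101, 3, 4, 101, 6], 101))  # [1, 0, 0, 0, 1, 1]
```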
|
||||
92
examples/seq2seq/callbacks.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
|
||||
def count_trainable_parameters(model):
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
return params
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqLoggingCallback(pl.Callback):
|
||||
@rank_zero_only
|
||||
def _write_logs(
|
||||
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||
) -> None:
|
||||
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
|
||||
metrics = trainer.callback_metrics
|
||||
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
|
||||
# Log results
|
||||
od = Path(pl_module.hparams.output_dir)
|
||||
if type_path == "test":
|
||||
results_file = od / "test_results.txt"
|
||||
generations_file = od / "test_generations.txt"
|
||||
else:
|
||||
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
|
||||
# If people want this it will be easy enough to add back.
|
||||
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
|
||||
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
|
||||
results_file.parent.mkdir(exist_ok=True)
|
||||
generations_file.parent.mkdir(exist_ok=True)
|
||||
with open(results_file, "a+") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key in ["log", "progress_bar", "preds"]:
|
||||
continue
|
||||
val = metrics[key]
|
||||
if isinstance(val, torch.Tensor):
|
||||
val = val.item()
|
||||
msg = f"{key}: {val:.6f}\n"
|
||||
writer.write(msg)
|
||||
|
||||
if not save_generations:
|
||||
return
|
||||
|
||||
if "preds" in metrics:
|
||||
content = "\n".join(metrics["preds"])
|
||||
generations_file.write_text(content)
|
||||
|
||||
@rank_zero_only
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
try:
|
||||
npars = pl_module.model.model.num_parameters()
|
||||
except AttributeError:
|
||||
npars = pl_module.model.num_parameters()
|
||||
|
||||
n_trainable_pars = count_trainable_parameters(pl_module)
|
||||
# mp stands for million parameters
|
||||
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
return self._write_logs(trainer, pl_module, "test")
|
||||
|
||||
|
||||
def get_checkpoint_callback(output_dir, metric):
|
||||
"""Saves the best model by validation ROUGE2 or BLEU score."""
|
||||
if metric == "rouge2":
|
||||
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
||||
elif metric == "bleu":
|
||||
exp = "{val_avg_bleu:.4f}-{step_count}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"seq2seq callbacks only support rouge2 and bleu, got {metric}. You can make your own by adding to this function."
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=os.path.join(output_dir, exp),
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
save_top_k=1,
|
||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
)
|
||||
return checkpoint_callback
|
||||
454
examples/seq2seq/distillation.py
Normal file
@@ -0,0 +1,454 @@
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from lightning_base import generic_train
|
||||
from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration
|
||||
|
||||
|
||||
try:
|
||||
from .finetune import SummarizationModule
|
||||
from .initialization_utils import init_student, copy_layers
|
||||
from .utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
)
|
||||
from .finetune import main as ft_main
|
||||
except ImportError:
|
||||
from finetune import SummarizationModule
|
||||
from finetune import main as ft_main
|
||||
from initialization_utils import init_student, copy_layers
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
)
|
||||
|
||||
|
||||
class BartSummarizationDistiller(SummarizationModule):
|
||||
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||
|
||||
def __init__(self, hparams):
|
||||
assert Path(hparams.data_dir).exists()
|
||||
student, student_cfg, teacher = self.pre_init(hparams)
|
||||
|
||||
super().__init__(hparams, model=student, config=student_cfg)
|
||||
self.teacher = teacher
|
||||
use_task_specific_params(self.teacher, "summarization")
|
||||
freeze_params(self.teacher)
|
||||
self.sanity_check_gradients()
|
||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
self.temperature = 2.0
|
||||
self.alpha_mlm = hparams.alpha_mlm
|
||||
self.alpha_ce = hparams.alpha_ce
|
||||
self.alpha_hid = hparams.alpha_hid
|
||||
# self.alpha_cos = hparams.alpha_cos
|
||||
self.alpha_encoder_loss = self.hparams.alpha_encoder_loss
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def sanity_check_gradients(self):
|
||||
assert_all_frozen(self.teacher)
|
||||
assert_all_frozen(self.model.model.decoder.embed_tokens)
|
||||
assert_all_frozen(self.model.model.encoder.embed_tokens)
|
||||
if self.different_encoder:
|
||||
assert any_requires_grad(self.model.model.encoder)
|
||||
else:
|
||||
freeze_params(self.model.model.encoder)
|
||||
del self.teacher.model.encoder
|
||||
|
||||
def pre_init(self, hparams):
|
||||
self.output_dir = Path(hparams.output_dir)
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
|
||||
student_updates = {
|
||||
"decoder_layers": hparams.student_decoder_layers,
|
||||
"encoder_layers": hparams.student_encoder_layers,
|
||||
}
|
||||
if hparams.length_penalty != -1:
|
||||
student_updates["length_penalty"] = hparams.length_penalty
|
||||
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
||||
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
||||
hparams.d_layer_to_copy = d_layers_to_copy
|
||||
hparams.e_layer_to_copy = e_layers_to_copy
|
||||
kw = teacher.config.to_diff_dict()
|
||||
kw.update(student_updates)
|
||||
# Copy weights
|
||||
student_cfg = BartConfig(**kw)
|
||||
student = BartForConditionalGeneration(student_cfg)
|
||||
student, _ = init_student(student, teacher)
|
||||
save_dir = self.output_dir.joinpath("student")
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||
student.save_pretrained(save_dir)
|
||||
hparams.model_name_or_path = str(save_dir)
|
||||
return student, student_cfg, teacher
|
||||
|
||||
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
||||
if teacher.config.model_type == "t5":
|
||||
return self.copy_t5_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.encoder_layers
|
||||
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
|
||||
if self.different_decoder:
|
||||
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
|
||||
if self.different_encoder:
|
||||
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
|
||||
|
||||
def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
||||
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.num_layers
|
||||
self.different_decoder = hparams.student_decoder_layers != teacher.config.num_layers
|
||||
if self.different_decoder:
|
||||
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
|
||||
if self.different_encoder:
|
||||
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
||||
|
||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
dataset = SummarizationDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, **self.dataset_kwargs)
|
||||
return dataset
|
||||
|
||||
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
||||
if mask is not None:
|
||||
# mask has False at padding_idx
|
||||
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
|
||||
s_logits_slct = torch.masked_select(student_outputs, sel_mask)
|
||||
t_logits_slct = torch.masked_select(teacher_outputs, sel_mask)
|
||||
else:
|
||||
t_logits_slct = teacher_outputs
|
||||
s_logits_slct = student_outputs
|
||||
return F.mse_loss(s_logits_slct, t_logits_slct)
|
||||
|
||||
def calc_ce_loss(self, mask, s_logits, t_logits):
|
||||
if mask is not None:
|
||||
# mask has False at padding_idx
|
||||
sel_mask = mask[:, :, None].expand_as(s_logits)
|
||||
s_logits_slct = torch.masked_select(
|
||||
s_logits, sel_mask
|
||||
) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
t_logits_slct = torch.masked_select(
|
||||
t_logits, sel_mask
|
||||
) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
else:
|
||||
t_logits_slct = t_logits
|
||||
s_logits_slct = s_logits # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
assert t_logits_slct.size() == s_logits_slct.size()
|
||||
loss_ce = (
|
||||
self.ce_loss_fct(
|
||||
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||
)
|
||||
* (self.temperature) ** 2
|
||||
)
|
||||
return loss_ce, s_logits_slct, t_logits_slct
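For a single decoder position, the quantity computed in `calc_ce_loss` is the KL divergence from the temperature-softened teacher distribution to the softened student distribution, multiplied by the squared temperature. The pure-Python sketch below works through that arithmetic on small lists. It is an illustration under that reading of `nn.KLDivLoss`, not a replacement for the batched `torch` code, and `softmax` and `distillation_kl` are names chosen for this example.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_kl(student_logits, teacher_logits, temperature=2.0):
    """T^2 * KL(teacher || student) on temperature-softened distributions."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(
        pt * (math.log(pt) - math.log(ps))
        for pt, ps in zip(p_teacher, p_student)
        if pt > 0
    )
    return kl * temperature ** 2


print(distillation_kl([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 0.0 when the logits agree
print(distillation_kl([0.0, 0.0, 0.0], [1.0, 2.0, 3.0]) > 0)  # True
```

The `temperature ** 2` factor follows the code above; it is commonly used to keep gradient magnitudes comparable when the temperature changes.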
|
||||
|
||||
def configure_optimizers(self):
|
||||
"Prepare optimizer and schedule (linear warmup and decay)"
|
||||
model = self.model
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": self.hparams.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
|
||||
self.opt = optimizer
|
||||
return [optimizer]
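The two parameter groups in `configure_optimizers` are chosen purely by substring matching on parameter names. The sketch below applies the same filter to a list of names, with no model required. The parameter names are hypothetical examples, not taken from a real checkpoint, and `split_by_weight_decay` is a helper written for this illustration.

```python
def split_by_weight_decay(param_names, no_decay=("bias", "LayerNorm.weight")):
    """Return (names that get weight decay, names that do not)."""
    decayed = [n for n in param_names if not any(nd in n for nd in no_decay)]
    not_decayed = [n for n in param_names if any(nd in n for nd in no_decay)]
    return decayed, not_decayed


names = [
    "encoder.layers.0.fc1.weight",
    "encoder.layers.0.fc1.bias",
    "encoder.LayerNorm.weight",
    "encoder.layers.0.final_layer_norm.weight",
]
decayed, not_decayed = split_by_weight_decay(names)
print(decayed)      # ['encoder.layers.0.fc1.weight', 'encoder.layers.0.final_layer_norm.weight']
print(not_decayed)  # ['encoder.layers.0.fc1.bias', 'encoder.LayerNorm.weight']
```

Because the match is on the literal string `LayerNorm.weight`, a layer-norm weight whose module is named differently (such as the hypothetical `final_layer_norm` above) falls into the decayed group.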
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
SummarizationModule.add_model_specific_args(parser, root_dir)
|
||||
parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
|
||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||
# parser.add_argument("--alpha_cos", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||
parser.add_argument("--length_penalty", type=float, default=-1)
|
||||
|
||||
return parser
|
||||
|
||||
def _step(self, batch):
|
||||
# assert is_frozen(self.teacher)
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
input_ids, src_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
decoder_input_ids = y[:, :-1].contiguous()
|
||||
labels = y[:, 1:].clone()
|
||||
labels[y[:, 1:] == pad_token_id] = -100
|
||||
# noinspection PyCallingNonCallable
|
||||
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
labels=labels,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
)
|
||||
|
||||
def zero_tensor():
|
||||
return torch.tensor(0.0).type_as(sloss)
|
||||
|
||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||
if self.different_encoder:
|
||||
with torch.no_grad():
|
||||
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.model.encoder(
|
||||
input_ids, attention_mask=src_mask, output_hidden_states=True
|
||||
)
|
||||
if self.hparams.alpha_encoder_loss > 0:
|
||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
|
||||
|
||||
hid_loss_enc = self.calc_hidden_loss(
|
||||
src_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
|
||||
)
|
||||
|
||||
teacher_enc_outputs = (enc_outputs,)
|
||||
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
||||
|
||||
with torch.no_grad():
|
||||
tloss, tlogits, tdec_hidden, _ = self.teacher(
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
encoder_outputs=teacher_enc_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
lm_labels=labels,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
||||
if self.alpha_hid > 0:
|
||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||
|
||||
blended_loss = (
|
||||
self.alpha_ce * loss_ce
|
||||
+ self.alpha_mlm * sloss
|
||||
+ self.hparams.alpha_encoder_loss * loss_encoder
|
||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||
)
|
||||
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
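The first lines of `_step` build teacher-forcing inputs by shifting the target sequence: the decoder sees every token except the last, and the labels are every token except the first, with padding positions replaced by -100 so that the loss ignores them. A list-based sketch of that shift follows. `shift_for_teacher_forcing` is a hypothetical name, and the token ids are made up for the example.

```python
def shift_for_teacher_forcing(y, pad_token_id, ignore_index=-100):
    """Return (decoder_input_ids, labels) for a batch of target id lists."""
    decoder_input_ids = [row[:-1] for row in y]
    labels = [
        [ignore_index if tok == pad_token_id else tok for tok in row[1:]]
        for row in y
    ]
    return decoder_input_ids, labels


batch = [[0, 5, 6, 2, 1, 1]]  # 1 plays the role of the pad token here
dec_in, labels = shift_for_teacher_forcing(batch, pad_token_id=1)
print(dec_in)  # [[0, 5, 6, 2, 1]]
print(labels)  # [[5, 6, 2, -100, -100]]
```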
|
||||
|
||||
def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, matches):
|
||||
assert not isinstance(
|
||||
hidden_states, torch.Tensor
|
||||
), f"expected list or tuple for hidden_states, got tensor of shape {hidden_states.shape}"
|
||||
assert not isinstance(
|
||||
hidden_states_T, torch.Tensor
|
||||
), f"expected list or tuple for hidden_states_T, got tensor of shape {hidden_states_T.shape}"
|
||||
mask = attention_mask.to(hidden_states[0])
|
||||
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||
hidden_losses = [
|
||||
(F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum()
|
||||
/ valid_count
|
||||
for i, j in enumerate(matches)
|
||||
]
|
||||
return sum(hidden_losses)
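Each term summed in `calc_hidden_loss` is a masked mean squared error: squared differences are zeroed at padding positions, and the sum is divided by the number of unmasked positions times the hidden size. The sketch below computes one such term for a single sequence using nested lists. `masked_mse` is a hypothetical helper, and the numbers are illustrative.

```python
def masked_mse(student, teacher, mask):
    """Mean squared error over positions where mask is 1.

    student, teacher: lists of shape [seq_len][hidden_size]
    mask: list of 0/1 values of length seq_len
    """
    hidden_size = len(student[0])
    valid_count = sum(mask) * hidden_size
    total = sum(
        m * (s - t) ** 2
        for s_row, t_row, m in zip(student, teacher, mask)
        for s, t in zip(s_row, t_row)
    )
    return total / valid_count


student = [[1.0, 2.0], [9.0, 9.0]]
teacher = [[0.0, 0.0], [0.0, 0.0]]
print(masked_mse(student, teacher, mask=[1, 0]))  # 2.5, the padded second position is ignored
```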
|
||||
|
||||
|
||||
class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||
def pre_init(self, hparams):
|
||||
raise NotImplementedError("T5 Distillation does not work yet")
|
||||
self.output_dir = Path(hparams.output_dir)
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
|
||||
n_layer = hparams.student_decoder_layers
|
||||
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this constraint so that we can do 12-6.
|
||||
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
|
||||
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
|
||||
student_updates = {"num_layers": n_layer}
|
||||
hparams.d_layer_to_copy = d_layers_to_copy
|
||||
hparams.e_layer_to_copy = e_layers_to_copy
|
||||
kw = teacher.config.to_diff_dict()
|
||||
|
||||
kw.update(student_updates)
|
||||
# Copy weights
|
||||
student_cfg = T5Config(**kw)
|
||||
student = T5ForConditionalGeneration(student_cfg)
|
||||
student, _ = init_student(student, teacher)
|
||||
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||
Path(hparams.output_dir).mkdir(exist_ok=True)
|
||||
task_specific_params = student.config.task_specific_params
|
||||
if task_specific_params is not None:
|
||||
student.config.update(task_specific_params.get("summarization", {}))  # TODO: don't hardcode
|
||||
save_dir = self.output_dir.joinpath("student")
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
|
||||
student.save_pretrained(save_dir)
|
||||
hparams.model_name_or_path = str(save_dir)
|
||||
return student, student_cfg, teacher
|
||||
|
||||
def freeze_embeds(self):
|
||||
freeze_params(self.model.shared)
|
||||
for d in [self.model.encoder, self.model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
def sanity_check_gradients(self):
|
||||
"""T5"""
|
||||
assert_all_frozen(self.teacher)
|
||||
assert_all_frozen(self.model.decoder.embed_tokens)
|
||||
assert_all_frozen(self.model.encoder.embed_tokens)
|
||||
if self.different_encoder:
|
||||
assert any_requires_grad(self.model.encoder)
|
||||
else:
|
||||
freeze_params(self.model.encoder)
|
||||
del self.teacher.model.encoder
|
||||
if self.different_decoder:
|
||||
assert any_requires_grad(self.model.decoder)
|
||||
else:
|
||||
freeze_params(self.model.decoder) # TODO(SS): very suspicious
|
||||
|
||||
def _step(self, batch):
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
decoder_input_ids = y[:, :-1].contiguous()
|
||||
labels = y[:, 1:].clone()
|
||||
labels[y[:, 1:] == pad_token_id] = -100
|
||||
# noinspection PyCallingNonCallable
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
|
||||
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
labels=labels,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
def zero_tensor():
|
||||
return torch.tensor(0.0).type_as(sloss)
|
||||
|
||||
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
|
||||
if self.different_encoder:
|
||||
with torch.no_grad():
|
||||
teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder(
|
||||
source_ids, attention_mask=source_mask, output_hidden_states=True, use_cache=False,
|
||||
)
|
||||
if self.hparams.alpha_encoder_loss > 0:
|
||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask)
|
||||
|
||||
hid_loss_enc = self.calc_hidden_loss(
|
||||
source_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
|
||||
)
|
||||
|
||||
teacher_enc_outputs = (enc_outputs,)
|
||||
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
|
||||
|
||||
with torch.no_grad():
|
||||
tloss, tlogits, tdec_hidden, _ = self.teacher(
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
encoder_outputs=teacher_enc_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
labels=labels,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
|
||||
if self.alpha_hid > 0:
|
||||
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
|
||||
|
||||
blended_loss = (
|
||||
self.alpha_ce * loss_ce
|
||||
+ self.alpha_mlm * sloss
|
||||
+ self.hparams.alpha_encoder_loss * loss_encoder
|
||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||
)
|
||||
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
|
||||
|
||||
|
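The `_step` method above combines four loss terms into a single scalar. As a minimal sketch with plain floats instead of tensors, the weighting works as shown below. The function name `blend_losses` is hypothetical and the loss values are made up for illustration; the alpha values are the ones set in this commit's test defaults.

```python
def blend_losses(loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec,
                 alpha_ce, alpha_mlm, alpha_encoder_loss, alpha_hid):
    """Weighted sum mirroring the blended_loss expression in _step."""
    return (
        alpha_ce * loss_ce
        + alpha_mlm * sloss
        + alpha_encoder_loss * loss_encoder
        + alpha_hid * (hid_loss_enc + hid_loss_dec)
    )


# Illustrative loss values only (assumption); alphas taken from the test defaults.
total = blend_losses(
    loss_ce=2.0, sloss=4.0, loss_encoder=0.5, hid_loss_enc=0.0, hid_loss_dec=0.0,
    alpha_ce=0.8, alpha_mlm=0.2, alpha_encoder_loss=0.4, alpha_hid=0.0,
)
print(round(total, 6))
```

With `alpha_hid` set to 0 the hidden-state terms drop out entirely, which is why `_step` only computes `hid_loss_dec` when `self.alpha_hid > 0`.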
||||
def create_module(args):
|
||||
t5 = "t5" in args.model_name_or_path
|
||||
if args.no_teacher:
|
||||
assert not args.enc_only
|
||||
module_cls = SummarizationModule
|
||||
elif t5:
|
||||
module_cls = T5SummarizationDistiller
|
||||
elif args.enc_only:
|
||||
raise ValueError("Deleted that")
|
||||
else:
|
||||
module_cls = BartSummarizationDistiller
|
||||
args.setup_cls: str = module_cls.__name__
|
||||
model = module_cls(args)
|
||||
return model
|
||||
|
||||
|
||||
def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
||||
exp_dir = ckpt_path.parent
|
||||
if dest_dir is None:
|
||||
dest_dir = exp_dir
|
||||
clash = list(dest_dir.glob("test_generations*"))
|
||||
if clash:
|
||||
print(f"SKIPPING to avoid overwriting {clash}")
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||
if "hparams" in ckpt:
|
||||
args = argparse.Namespace(**ckpt["hparams"])
|
||||
else:
|
||||
args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl"))
|
||||
args.resume_from_checkpoint = str(ckpt_path)
|
||||
args.do_train = False
|
||||
args.output_dir = str(dest_dir)
|
||||
args.n_gpu = 1
|
||||
args.eval_batch_size = 16
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
model = create_module(args)
|
||||
trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False)
|
||||
trainer.test(model)
|
||||
|
||||
|
||||
def get_layers_to_copy(n_to_get, tot):
|
||||
all_layers = list(range(tot))
|
||||
if tot == 12: # Alternating for special cases
|
||||
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
|
||||
1: [0],
|
||||
2: [0, 6],
|
||||
3: [0, 6, 11],
|
||||
4: [0, 4, 8, 11],
|
||||
6: [0, 2, 4, 7, 9, 11],
|
||||
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
||||
12: all_layers,
|
||||
}
|
||||
return layers_to_copy[n_to_get]
|
||||
else:
|
||||
return all_layers[:n_to_get] # TODO: better version on theseus-bart branch
|
||||
|
||||
|
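To illustrate which teacher layers `get_layers_to_copy` selects, here is a self-contained sketch that re-implements the same lookup. The name `pick_teacher_layers` is hypothetical; the table is copied from the function above.

```python
from typing import List

# Copied from the 12-layer special case: student depth -> teacher layer indices.
TWELVE_LAYER_TABLE = {
    1: [0],
    2: [0, 6],
    3: [0, 6, 11],
    4: [0, 4, 8, 11],
    6: [0, 2, 4, 7, 9, 11],
    9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
    12: list(range(12)),
}


def pick_teacher_layers(n_student: int, n_teacher: int) -> List[int]:
    if n_teacher == 12:
        # Raises KeyError for student sizes missing from the table, e.g. 5.
        return TWELVE_LAYER_TABLE[n_student]
    # Any other teacher depth: take the first n_student layers.
    return list(range(n_teacher))[:n_student]


print(pick_teacher_layers(3, 12))
print(pick_teacher_layers(3, 6))
```

For a 12-layer teacher the selected layers are spread across the stack, while every other teacher depth falls back to the first `n_student` layers.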
||||
def distill_main(args):
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
|
||||
model = create_module(args)
|
||||
return ft_main(args, model=model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
|
||||
distill_main(args)
|
||||
343
examples/seq2seq/finetune.py
Normal file
@@ -0,0 +1,343 @@
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
try:
|
||||
from .utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
)
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
except ImportError:
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
)
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SummarizationModule(BaseTransformer):
|
||||
mode = "summarization"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ROUGE_KEYS
|
||||
val_metric = "rouge2"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||
use_task_specific_params(self.model, "summarization")
|
||||
save_git_info(self.hparams.output_dir)
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
|
||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||
pickle_save(self.hparams, self.hparams_save_path)
|
||||
self.step_count = 0
|
||||
self.metrics = defaultdict(list)
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
max_source_length=self.hparams.max_source_length,
|
||||
prefix=self.model.config.prefix or "",
|
||||
)
|
||||
n_observations_per_split = {
|
||||
"train": self.hparams.n_train,
|
||||
"val": self.hparams.n_val,
|
||||
"test": self.hparams.n_test,
|
||||
}
|
||||
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
|
||||
|
||||
self.target_lens = {
|
||||
"train": self.hparams.max_target_length,
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
if self.hparams.freeze_embeds:
|
||||
self.freeze_embeds()
|
||||
if self.hparams.freeze_encoder:
|
||||
freeze_params(self.model.model.encoder) # TODO: this will break for t5
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
try:
|
||||
freeze_params(self.model.model.shared)
|
||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
except AttributeError:
|
||||
freeze_params(self.model.shared)
|
||||
for d in [self.model.encoder, self.model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
|
||||
def ids_to_clean_text(self, generated_ids: List[int]):
|
||||
gen_text = self.tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
return lmap(str.strip, gen_text)
|
||||
|
||||
def _step(self, batch: dict) -> Tuple:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
y_ids = y[:, :-1].contiguous()
|
||||
lm_labels = y[:, 1:].clone()
|
||||
lm_labels[y[:, 1:] == pad_token_id] = -100
|
||||
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
|
||||
loss = outputs[0]
|
||||
return (loss,)
|
||||
|
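The shift performed in `_step` can be hard to picture on tensors. The following plain-Python sketch applies the same slicing to a single padded sequence. The token ids and `PAD = 1` are made-up values, and `shift_for_teacher_forcing` is a hypothetical helper that is not part of this commit.

```python
from typing import List, Tuple

PAD = 1        # hypothetical pad_token_id (assumption)
IGNORE = -100  # label value that _step assigns to padded positions


def shift_for_teacher_forcing(y: List[int]) -> Tuple[List[int], List[int]]:
    decoder_input_ids = y[:-1]  # everything except the last token
    # Everything except the first token, with pad positions masked out.
    labels = [IGNORE if tok == PAD else tok for tok in y[1:]]
    return decoder_input_ids, labels


dec_in, labels = shift_for_teacher_forcing([0, 17, 42, 2, PAD, PAD])
print(dec_in)
print(labels)
```

At each position the decoder input is the previous target token and the label is the next one. The value -100 is the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions do not contribute to the loss.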
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
loss_tensors = self._step(batch)
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
return {"loss": loss_tensors[0], "log": logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
return self._generative_step(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
|
||||
rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
|
||||
rouges.update({k: v.item() for k, v in losses.items()})
|
||||
losses.update(rouges)
|
||||
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
metrics["step_count"] = self.step_count
|
||||
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
|
||||
|
||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||
self.metrics[type_path].append(latest_metrics)
|
||||
save_json(self.metrics, self.metrics_save_path)
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_rouge(preds, target)
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||
t0 = time.time()
|
||||
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
|
||||
gen_time = (time.time() - t0) / source_ids.shape[0]
|
||||
preds = self.ids_to_clean_text(generated_ids)
|
||||
target = self.ids_to_clean_text(y)
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
rouge: Dict = self.calc_generative_metrics(preds, target)
|
||||
summ_len = np.mean(lmap(len, generated_ids))
|
||||
base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
|
||||
return base_metrics
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_epoch_end(outputs, prefix="test")
|
||||
|
||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = SummarizationDataset(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
max_target_length=max_target_length,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
sampler = None
|
||||
if self.hparams.sortish_sampler and type_path == "train":
|
||||
assert self.hparams.gpus <= 1 # TODO: assert earlier
|
||||
sampler = dataset.make_sortish_sampler(batch_size)
|
||||
shuffle = False
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=shuffle,
|
||||
num_workers=self.num_workers,
|
||||
sampler=sampler,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
t_total = (
|
||||
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
|
||||
// self.hparams.gradient_accumulation_steps
|
||||
* float(self.hparams.num_train_epochs)
|
||||
)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
self.lr_scheduler = scheduler
|
||||
return dataloader
|
||||
|
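The `t_total` expression in `train_dataloader` is easier to check with concrete numbers. This is a minimal sketch of the same arithmetic; the function name `total_training_steps` and the input values are hypothetical.

```python
def total_training_steps(n_examples, train_batch_size, gpus, grad_accum_steps, num_epochs):
    # Batches per epoch across all devices (gpus=0 is treated as one device).
    per_epoch = n_examples // (train_batch_size * max(1, gpus))
    # Optimizer steps per epoch, then scaled by the number of epochs.
    return per_epoch // grad_accum_steps * float(num_epochs)


steps = total_training_steps(
    n_examples=1000, train_batch_size=4, gpus=0, grad_accum_steps=2, num_epochs=3
)
print(steps)
```

Note that the result is a float because of the `float(num_train_epochs)` factor, and that both floor divisions discard any partial batch or partial accumulation step.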
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||
add_generic_args(parser, root_dir)
|
||||
parser.add_argument(
|
||||
"--max_source_length",
|
||||
default=1024,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
default=56,
|
||||
type=int,
|
||||
help="The maximum total target sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
|
||||
type=int,
|
||||
help="The maximum total target sequence length (validation) after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=142,
|
||||
type=int,
|
||||
help="The maximum total target sequence length (test) after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
|
||||
)
|
||||
parser.add_argument("--freeze_encoder", action="store_true")
|
||||
parser.add_argument("--freeze_embeds", action="store_true")
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument(
|
||||
"--task", type=str, default="summarization", required=False, help="Task to run: 'summarization' (default) or 'translation'."
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
class TranslationModule(SummarizationModule):
|
||||
mode = "translation"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["bleu"]
|
||||
val_metric = "bleu"
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu_score(preds, target)
|
||||
|
||||
|
||||
def main(args, model=None) -> SummarizationModule:
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
if model is None:
|
||||
if args.task == "summarization":
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
else:
|
||||
model: SummarizationModule = TranslationModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=dataset)
|
||||
|
||||
elif args.logger == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
logger=logger,
|
||||
# TODO: early stopping callback seems messed up
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
if not args.do_predict:
|
||||
return model
|
||||
|
||||
model.hparams.test_checkpoint = ""
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
|
||||
if checkpoints:
|
||||
model.hparams.test_checkpoint = checkpoints[-1]
|
||||
trainer.resume_from_checkpoint = checkpoints[-1]
|
||||
trainer.logger.log_hyperparams(model.hparams)
|
||||
trainer.test(model) # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
|
||||
return model
|
||||
|
||||
|
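`main` picks the checkpoint to test with `sorted(glob.glob(...))[-1]`, which is a lexicographic sort on file names. The sketch below shows what that selects, assuming every checkpoint follows the naming pattern asserted in this commit's tests (`val_avg_rouge2=0.0000-step_count=2.ckpt`). The metric values and step counts are made up.

```python
# Hypothetical checkpoint names in the format used by the test suite.
names = [
    "val_avg_rouge2=0.0900-step_count=10.ckpt",
    "val_avg_rouge2=0.1500-step_count=3.ckpt",
    "val_avg_rouge2=0.1200-step_count=7.ckpt",
]

# Same selection rule as in main(): lexicographically last name.
chosen = sorted(names)[-1]
print(chosen)
```

Because the metric comes first in the name and is printed with a fixed number of decimals, the lexicographically last file is the one with the highest metric value, not necessarily the most recent step. This holds only under the stated naming assumption.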
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
14
examples/seq2seq/finetune.sh
Executable file
@@ -0,0 +1,14 @@
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# Proper usage is documented in the README: you need to specify data_dir, output_dir and model_name_or_path
|
||||
python finetune.py \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--gpus 1 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val 1000 \
|
||||
--val_check_interval 0.1 \
|
||||
--sortish_sampler \
|
||||
"$@"
|
||||
32
examples/seq2seq/finetune_bart_tiny.sh
Executable file
@@ -0,0 +1,32 @@
|
||||
# Script for verifying that finetune.py can be invoked from its directory
|
||||
|
||||
# Get tiny dataset with cnn_dm format (4 examples for train, val, test)
|
||||
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_tiny.tgz
|
||||
tar -xzvf cnn_tiny.tgz
|
||||
rm cnn_tiny.tgz
|
||||
|
||||
export OUTPUT_DIR_NAME=bart_utest_output
|
||||
export CURRENT_DIR=${PWD}
|
||||
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
|
||||
|
||||
# Make output directory if it doesn't exist
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py and utils.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
python finetune.py \
|
||||
--data_dir=cnn_tiny/ \
|
||||
--model_name_or_path=sshleifer/bart-tiny-random \
|
||||
--learning_rate=3e-5 \
|
||||
--train_batch_size=2 \
|
||||
--eval_batch_size=2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--num_train_epochs=1 \
|
||||
--gpus=0 \
|
||||
--do_train "$@"
|
||||
|
||||
rm -rf cnn_tiny
|
||||
rm -rf "$OUTPUT_DIR"
|
||||
|
||||
|
||||
|
||||
18
examples/seq2seq/finetune_t5.sh
Executable file
@@ -0,0 +1,18 @@
|
||||
export OUTPUT_DIR_NAME=t5
|
||||
export CURRENT_DIR=${PWD}
|
||||
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
|
||||
|
||||
# Make output directory if it doesn't exist
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
python finetune.py \
|
||||
--data_dir=./cnn-dailymail/cnn_dm \
|
||||
--model_name_or_path=t5-large \
|
||||
--learning_rate=3e-5 \
|
||||
--train_batch_size=4 \
|
||||
--eval_batch_size=4 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_train "$@"
|
||||
20
examples/seq2seq/initialization_utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import List
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
def init_student(student, teacher):
|
||||
teacher_state_dict = teacher.state_dict()
|
||||
info = student.load_state_dict(teacher_state_dict, strict=False)
|
||||
assert info.missing_keys == [], info.missing_keys
|
||||
return student, info
|
||||
|
||||
|
||||
def copy_decoder_layers(teacher, student, l2copy=[0, 2, 4, 7, 9, 11]):
|
||||
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, l2copy)
|
||||
|
||||
|
||||
def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, layers_to_copy: List) -> None:
|
||||
layers_to_copy = nn.ModuleList([l for i, l in enumerate(teacher_layers) if i in layers_to_copy])
|
||||
assert len(student_layers) == len(layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}"
|
||||
student_layers.load_state_dict(layers_to_copy.state_dict())
|
||||
10
examples/seq2seq/run_distiller.sh
Executable file
@@ -0,0 +1,10 @@
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
python distillation.py \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--fp16 \
|
||||
--val_check_interval 0.1 \
|
||||
"$@"
|
||||
88
examples/seq2seq/run_eval.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
try:
|
||||
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
|
||||
except ImportError:
|
||||
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
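As a quick illustration of how `chunks` batches the input examples, the sketch below repeats the generator from this file and runs it on a made-up list of five articles.

```python
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


# Hypothetical inputs: five examples, batch size 2.
articles = ["a1", "a2", "a3", "a4", "a5"]
batches = list(chunks(articles, 2))
print(batches)
```

The last batch is smaller whenever the number of examples is not a multiple of the batch size, so no example is dropped.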
||||
def generate_summaries_or_translations(
|
||||
examples: list,
|
||||
out_file: str,
|
||||
model_name: str,
|
||||
batch_size: int = 8,
|
||||
device: str = DEFAULT_DEVICE,
|
||||
fp16=False,
|
||||
**gen_kwargs,
|
||||
) -> None:
|
||||
fout = Path(out_file).open("w", encoding="utf-8")
|
||||
model_name = str(model_name)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
|
||||
if fp16:
|
||||
model = model.half()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# update config with summarization specific params
|
||||
use_task_specific_params(model, "summarization")
|
||||
|
||||
for batch in tqdm(list(chunks(examples, batch_size))):
|
||||
if "t5" in model_name:
|
||||
batch = [model.config.prefix + text for text in batch]
|
||||
batch = tokenizer.batch_encode_plus(
|
||||
batch, max_length=1024, return_tensors="pt", truncation=True, pad_to_max_length=True
|
||||
).to(device)
|
||||
summaries = model.generate(**batch, **gen_kwargs)
|
||||
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
for hypothesis in dec:
|
||||
fout.write(hypothesis + "\n")
|
||||
fout.flush()
|
||||
|
||||
|
||||
def run_generate():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
|
||||
parser.add_argument("output_path", type=str, help="where to save summaries")
|
||||
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn, t5-base, etc.")
|
||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
|
||||
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge or bleu scores in json format")
|
||||
parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge")
|
||||
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
|
||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
args = parser.parse_args()
|
||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||
|
||||
generate_summaries_or_translations(
|
||||
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
|
||||
)
|
||||
|
||||
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
|
||||
scores = {}
|
||||
if args.reference_path is not None:
|
||||
score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric]
|
||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
|
||||
scores: dict = score_fn(output_lns, reference_lns)
|
||||
if args.score_path is not None:
|
||||
json.dump(scores, open(args.score_path, "w+"))  # write to the path given by --score_path
|
||||
return scores
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_generate()
|
||||
252
examples/seq2seq/test_seq2seq_examples.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import main
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .utils import SummarizationDataset, lmap, load_json
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"logger": "default",
|
||||
"length_penalty": 0.5,
|
||||
"cache_dir": "",
|
||||
"task": "summarization",
|
||||
"num_workers": 2,
|
||||
"alpha_hid": 0,
|
||||
"freeze_embeds": True,
|
||||
"enc_only": False,
|
||||
"tgt_suffix": "",
|
||||
"resume_from_checkpoint": None,
|
||||
"sortish_sampler": True,
|
||||
"student_decoder_layers": 1,
|
||||
"val_check_interval": 1.0,
|
||||
"output_dir": "",
|
||||
"fp16": CUDA_AVAILABLE,
|
||||
"no_teacher": False,
|
||||
"fp16_opt_level": "O1",
|
||||
"gpus": 1 if CUDA_AVAILABLE else 0,
|
||||
"n_tpu_cores": 0,
|
||||
"max_grad_norm": 1.0,
|
||||
"do_train": True,
|
||||
"do_predict": True,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"server_ip": "",
|
||||
"server_port": "",
|
||||
"seed": 42,
|
||||
"model_name_or_path": "sshleifer/bart-tiny-random",
|
||||
"config_name": "",
|
||||
"tokenizer_name": "facebook/bart-large",
|
||||
"do_lower_case": False,
|
||||
"learning_rate": 0.3,
|
||||
"weight_decay": 0.0,
|
||||
"adam_epsilon": 1e-08,
|
||||
"warmup_steps": 0,
|
||||
"num_train_epochs": 1,
|
||||
"train_batch_size": 2,
|
||||
"eval_batch_size": 2,
|
||||
"max_source_length": 12,
|
||||
"max_target_length": 12,
|
||||
"val_max_target_length": 12,
|
||||
"test_max_target_length": 12,
|
||||
"fast_dev_run": False,
|
||||
"no_cache": False,
|
||||
"n_train": -1,
|
||||
"n_val": -1,
|
||||
"n_test": -1,
|
||||
"student_encoder_layers": 1,
|
||||
"alpha_loss_encoder": 0.0,
|
||||
"freeze_encoder": False,
|
||||
"auto_scale_batch_size": False,
|
||||
}
|
||||
|
||||
|
||||
def _dump_articles(path: Path, articles: list):
|
||||
with path.open("w") as f:
|
||||
f.write("\n".join(articles))
|
||||
|
||||
|
||||
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
|
||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
|
||||
|
||||
def make_test_data_dir(**kwargs):
|
||||
tmp_dir = Path(tempfile.mkdtemp(**kwargs))
|
||||
for split in ["train", "val", "test"]:
|
||||
_dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
|
||||
_dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class TestSummarizationDistiller(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
return cls
|
||||
|
||||
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
|
||||
def test_multigpu(self):
|
||||
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_no_teacher(self):
|
||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_checkpointing_with_teacher(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
num_train_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
)
|
||||
model = self._test_distiller_cli(updates, check_contents=False)
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
self.assertEqual(1, len(ckpts))
|
||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
self.assertEqual(len(transformer_ckpts), 2)
|
||||
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
|
||||
out_path = tempfile.mktemp()
|
||||
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
|
||||
self.assertTrue(Path(out_path).exists())
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
    @unittest.skip("T5 distillation is broken at the moment")
    def test_distill_t5(self):
        updates = dict(
            student_encoder_layers=1,
            student_decoder_layers=1,
            alpha_hid=2.0,
            teacher=T5_TINY,
            model_name_or_path=T5_TINY,
            tokenizer_name=T5_TINY,
        )
        self._test_distiller_cli(updates)

    def _test_distiller_cli(self, updates, check_contents=True):
        default_updates = dict(
            train_batch_size=1,
            eval_batch_size=2,
            num_train_epochs=2,
            alpha_mlm=0.2,
            alpha_ce=0.8,
            do_predict=True,
            model_name_or_path="sshleifer/tinier_bart",
            teacher=CHEAP_ARGS["model_name_or_path"],
            val_check_interval=0.5,
            alpha_encoder_loss=0.4,
        )
        default_updates.update(updates)
        args_d: dict = CHEAP_ARGS.copy()
        tmp_dir = make_test_data_dir()
        output_dir = tempfile.mkdtemp(prefix="output_")

        args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
        model = distill_main(argparse.Namespace(**args_d))
        if not check_contents:
            return model
        contents = os.listdir(output_dir)
        ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt"  # "val_avg_rouge2=0.0000-epoch=1.ckpt"  # "epoch=1-val_avg_rouge2=0.0000.ckpt"
        contents = {os.path.basename(p) for p in contents}
        self.assertIn(ckpt_name, contents)

        self.assertIn("test_generations.txt", contents)
        self.assertIn("test_results.txt", contents)

        metrics = load_json(model.metrics_save_path)
        last_step_stats = metrics["val"][-1]
        self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
        self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
        self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
        desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
        self.assertEqual(len(metrics["val"]), desired_n_evals)
        self.assertEqual(len(metrics["test"]), 1)
        return model


@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval_bart(model):
    tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"

    output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
    assert not output_file_name.exists()
    articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
    _dump_articles(tmp, articles)
    testargs = ["run_eval.py", str(tmp), str(output_file_name), model]  # TODO: test score_path
    with patch.object(sys, "argv", testargs):
        run_generate()
        assert Path(output_file_name).exists()
        os.remove(Path(output_file_name))


@pytest.mark.parametrize(
    ["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)
def test_finetune(model):
    args_d: dict = CHEAP_ARGS.copy()
    task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
    tmp_dir = make_test_data_dir()
    output_dir = tempfile.mkdtemp(prefix="output_")
    args_d.update(
        data_dir=tmp_dir,
        model_name_or_path=model,
        tokenizer_name=None,
        train_batch_size=2,
        eval_batch_size=2,
        output_dir=output_dir,
        do_predict=True,
        task=task,
    )
    assert "n_train" in args_d
    args = argparse.Namespace(**args_d)
    main(args)


@pytest.mark.parametrize(
    ["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)
def test_dataset(tok):
    tokenizer = AutoTokenizer.from_pretrained(tok)
    tmp_dir = make_test_data_dir()
    max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
    max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
    trunc_target = 4
    train_dataset = SummarizationDataset(
        tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
    )
    dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
    for batch in dataloader:
        assert batch["attention_mask"].shape == batch["input_ids"].shape
        # show that articles were trimmed.
        assert batch["input_ids"].shape[1] == max_len_source
        assert 20 >= batch["input_ids"].shape[1]  # trimmed significantly
        # show that targets were truncated
        assert batch["decoder_input_ids"].shape[1] == trunc_target  # Truncated
        assert max_len_target > trunc_target  # Truncated
24
examples/seq2seq/train_distilbart_cnn.sh
Executable file
@@ -0,0 +1,24 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"

export BS=32
export GAS=1

python finetune.py \
    --learning_rate=3e-5 \
    --fp16 \
    --gpus 1 \
    --do_train \
    --do_predict \
    --val_check_interval 0.25 \
    --n_val 500 \
    --num_train_epochs 2 \
    --freeze_encoder --freeze_embeds --data_dir $CNN_DIR \
    --max_target_length 142 --val_max_target_length=142 \
    --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
    --data_dir $CNN_DIR \
    --model_name_or_path sshleifer/student_cnn_12_6 \
    --tokenizer_name facebook/bart-large \
    --output_dir distilbart-cnn-12-6 \
    "$@"

20
examples/seq2seq/train_distilbart_xsum.sh
Executable file
@@ -0,0 +1,20 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export BS=16
export GAS=2
python distillation.py \
  --learning_rate=3e-4 \
  --do_train \
  --do_predict \
  --fp16 \
  --val_check_interval 0.1 --n_val 1000 \
  --teacher facebook/bart-large-xsum --data_dir $XSUM_DIR \
  --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
  --student_decoder_layers 6 --student_encoder_layers 12 \
  --freeze_encoder --freeze_embeds \
  --model_name_or_path IGNORED \
  --alpha_hid=3. --length_penalty=0.5 \
  --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
  --tokenizer_name facebook/bart-large \
  --output_dir distilbart_xsum_12_6 \
  "$@"
261
examples/seq2seq/utils.py
Normal file
@@ -0,0 +1,261 @@
import itertools
import json
import os
import pickle
from pathlib import Path
from typing import Callable, Dict, Iterable, List

import git
import numpy as np
import torch
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm


def encode_file(
    tokenizer,
    data_path,
    max_length,
    pad_to_max_length=True,
    return_tensors="pt",
    overwrite_cache=False,
    prefix="",
    tok_name="",
):
    cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
    if not overwrite_cache and cache_path.exists():
        try:
            examples = torch.load(cache_path)
            assert isinstance(examples, list)
            return examples

        except Exception:
            print(f"failed to load from {cache_path}, retokenizing {data_path}")
    data_path = Path(data_path)

    lns = lmap(str.strip, data_path.open().readlines())
    lns = [prefix + text for text in lns]
    assert lns, f"found empty file at {data_path}"
    examples = []
    for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
        tokenized = tokenizer.batch_encode_plus(
            [text],
            max_length=max_length,
            pad_to_max_length=pad_to_max_length,
            add_prefix_space=True,
            truncation=True,
            return_tensors=return_tensors,
        )
        assert tokenized.input_ids.shape[1] == max_length
        examples.append(tokenized)
    torch.save(lmap(dict, examples), cache_path.open("wb"))
    return examples


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu_score(output_lns, refs_lns) -> dict:
    return {"bleu": corpus_bleu(output_lns, [refs_lns]).score}


def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])


class SummarizationDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        type_path="train",
        max_source_length=1024,
        max_target_length=56,
        n_obs=None,
        overwrite_cache=False,
        prefix="",
    ):
        super().__init__()
        tok_name = tokenizer.__class__.__name__.lower().replace("tokenizer", "")  # rstrip("tokenizer") strips a character set, not a suffix
        self.source = encode_file(
            tokenizer,
            os.path.join(data_dir, type_path + ".source"),
            max_source_length,
            overwrite_cache=overwrite_cache,
            prefix=prefix,
            tok_name=tok_name,
        )
        tgt_path = os.path.join(data_dir, type_path + ".target")
        if hasattr(tokenizer, "set_lang"):
            tokenizer.set_lang("ro_RO")  # HACK: only applies to mbart
        self.target = encode_file(
            tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
        )
        if n_obs is not None:
            self.source = self.source[:n_obs]
            self.target = self.target[:n_obs]
        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        source_ids = self.source[index]["input_ids"].squeeze()
        target_ids = self.target[index]["input_ids"].squeeze()
        src_mask = self.source[index]["attention_mask"].squeeze()
        return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}

    @staticmethod
    def trim_seq2seq_batch(batch, pad_token_id):
        y = trim_batch(batch["decoder_input_ids"], pad_token_id)
        source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
        return source_ids, source_mask, y

    def collate_fn(self, batch) -> dict:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
        return batch

    @property
    def src_lens(self):  # Can delete?
        return lmap(len, self.source)

    @property
    def tgt_lens(self):
        return lmap(len, self.target)

    def make_sortish_sampler(self, batch_size):
        return SortishSampler(self.source, batch_size)


class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size):
        self.data, self.bs = data, batch_size

    def key(self, i):
        return len(self.data[i])

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        idxs = np.random.permutation(len(self.data))
        sz = self.bs * 50
        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        sz = self.bs
        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=int)
        sort_idx = np.concatenate((ck_idx[0], sort_idx))
        return iter(sort_idx)


def use_task_specific_params(model, task):
    # update config with task-specific params (e.g. summarization or translation)
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        model.config.update(task_specific_params.get(task, {}))


def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path):
    with open(path, "w") as f:
        json.dump(content, f, indent=4)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
    }
    return repo_infos


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    return {k: v.mid.fmeasure for k, v in result.items()}


def freeze_params(model: nn.Module):
    for par in model.parameters():
        par.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"