Reorganize examples (#9010)

* Reorganize example folder

* Continue reorganization

* Change requirements for tests

* Final cleanup

* Finish regroup with tests all passing

* Copyright

* Requirements and readme

* Make a full link for the documentation

* Address review comments

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Add symlink

* Reorg again

* Apply suggestions from code review

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>

* Adapt title

* Update to new strucutre

* Remove test

* Update READMEs

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger
2020-12-11 10:07:02 -05:00
committed by GitHub
parent 86896de064
commit 783d7d2629
215 changed files with 4454 additions and 1193 deletions

View File

@@ -0,0 +1,430 @@
## Sequence to Sequence Training and Evaluation
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
Author: Sam Shleifer (https://github.com/sshleifer)
### Supported Architectures
- `BartForConditionalGeneration` (and anything that inherits from it)
- `MarianMTModel`
- `PegasusForConditionalGeneration`
- `MBartForConditionalGeneration`
- `FSMTForConditionalGeneration`
- `T5ForConditionalGeneration`
## Datasets
#### XSUM
```bash
cd examples/contrib/pytorch-lightning/seq2seq
wget https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
this should make a directory called `xsum/` with files like `test.source`.
To use your own data, copy that files format. Each article to be summarized is on its own line.
#### CNN/DailyMail
```bash
cd examples/contrib/pytorch-lightning/seq2seq
wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz
tar -xzvf cnn_dm_v2.tgz # empty lines removed
mv cnn_cln cnn_dm
export CNN_DIR=${PWD}/cnn_dm
```
this should make a directory called `cnn_dm/` with 6 files.
#### WMT16 English-Romanian Translation Data
download with this command:
```bash
wget https://cdn-datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export ENRO_DIR=${PWD}/wmt_en_ro
```
this should make a directory called `wmt_en_ro/` with 6 files.
#### WMT English-German
```bash
wget https://cdn-datasets.huggingface.co/translation/wmt_en_de.tgz
tar -xzvf wmt_en_de.tgz
export DATA_DIR=${PWD}/wmt_en_de
```
#### FSMT datasets (wmt)
Refer to the scripts starting with `eval_` under:
https://github.com/huggingface/transformers/tree/master/scripts/fsmt
#### Pegasus (multiple datasets)
Multiple eval datasets are available for download from:
https://github.com/stas00/porting/tree/master/datasets/pegasus
#### Your Data
If you are using your own data, it must be formatted as one directory with 6 files:
```
train.source
train.target
val.source
val.target
test.source
test.target
```
The `.source` files are the input, the `.target` files are the desired output.
### Potential issues
- native AMP (`--fp16` and no apex) may lead to a huge memory leak and require 10x gpu memory. This has been fixed in pytorch-nightly and the minimal official version to have this fix will be pytorch-1.8. Until then if you have to use mixed precision please use AMP only with pytorch-nightly or NVIDIA's apex. Reference: https://github.com/huggingface/transformers/issues/8403
### Tips and Tricks
General Tips:
- since you need to run from this folder, and likely need to modify code, the easiest workflow is fork transformers, clone your fork, and run `pip install -e .` before you get started.
- try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr per epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default works best).
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
- Read scripts before you run them!
Summarization Tips:
- (summ) 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
**Update 2018-07-18**
Datasets: `LegacySeq2SeqDataset` will be used for all tokenizers without a `prepare_seq2seq_batch` method. Otherwise, `Seq2SeqDataset` will be used.
Future work/help wanted: A new dataset to support multilingual tasks.
### Finetuning Scripts
All finetuning bash scripts call finetune.py (or distillation.py) with reasonable command line arguments. They usually require extra command line arguments to work.
To see all the possible command line options, run:
```bash
./finetune.py --help
```
### Finetuning Training Params
To override the pretrained model's training params, you can pass them to `./finetune.sh`:
```bash
./finetune.sh \
[...]
--encoder_layerdrop 0.1 \
--decoder_layerdrop 0.1 \
--dropout 0.1 \
--attention_dropout 0.1 \
```
### Summarization Finetuning
Run/modify `finetune.sh`
The following command should work on a 16GB GPU:
```bash
./finetune.sh \
--data_dir $XSUM_DIR \
--train_batch_size=1 \
--eval_batch_size=1 \
--output_dir=xsum_results \
--num_train_epochs 6 \
--model_name_or_path facebook/bart-large
```
There is a starter finetuning script for pegasus at `finetune_pegasus_xsum.sh`.
### Translation Finetuning
First, follow the wmt_en_ro download instructions.
Then you can finetune mbart_cc25 on english-romanian with the following command.
**Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it.
Best performing command:
```bash
# optionally
export ENRO_DIR='wmt_en_ro' # Download instructions above
# export WANDB_PROJECT="MT" # optional
export MAX_LEN=128
export BS=4
./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --label_smoothing 0.1 --fp16_opt_level=O1 --logger_name wandb --sortish_sampler
```
This should take < 6h/epoch on a 16GB v100 and achieve test BLEU above 26
To get results in line with fairseq, you need to do some postprocessing. (see `romanian_postprocessing.md`)
MultiGPU command
(using 8 GPUS as an example)
```bash
export ENRO_DIR='wmt_en_ro' # Download instructions above
# export WANDB_PROJECT="MT" # optional
export MAX_LEN=128
export BS=4
./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --gpus 8 --logger_name wandb
```
### Finetuning Outputs
As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine).
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
```bash
output_dir
├── best_tfmr # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
│   ├── config.json
│   ├── merges.txt
│   ├── pytorch_model.bin
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── git_log.json # repo, branch, and commit hash
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score. (it will be called BLEU for MT)
├── metrics.json # new validation metrics will continually be appended to this
├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
│   ├── config.json
│   └── pytorch_model.bin
├── test_generations.txt
# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
├── test_results.txt # a convenience file with the test set metrics. This data is also in metrics.json['test']
├── hparams.pkl # the command line args passed after some light preprocessing. Should be saved fairly quickly.
```
After training, you can recover the best checkpoint by running
```python
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
### Converting pytorch-lightning checkpoints
pytorch lightning ``-do_predict`` often fails, after you are done training, the best way to evaluate your model is to convert it.
This should be done for you, with a file called `{save_dir}/best_tfmr`.
If that file doesn't exist but you have a lightning `.ckpt` file, you can run
```bash
python convert_pl_checkpoint_to_hf.py PATH_TO_CKPT randomly_initialized_hf_model_path save_dir/best_tfmr
```
Then either `run_eval` or `run_distributed_eval` with `save_dir/best_tfmr` (see previous sections)
# Experimental Features
These features are harder to use and not always useful.
### Dynamic Batch Size for MT
`finetune.py` has a command line arg `--max_tokens_per_batch` that allows batches to be dynamically sized.
This feature can only be used:
- with fairseq installed
- on 1 GPU
- without sortish sampler
- after calling `./save_len_file.py $tok $data_dir`
For example,
```bash
./save_len_file.py Helsinki-NLP/opus-mt-en-ro wmt_en_ro
./dynamic_bs_example.sh --max_tokens_per_batch=2000 --output_dir benchmark_dynamic_bs
```
splits `wmt_en_ro/train` into 11,197 uneven lengthed batches and can finish 1 epoch in 8 minutes on a v100.
For comparison,
```bash
./dynamic_bs_example.sh --sortish_sampler --train_batch_size 48
```
uses 12,723 batches of length 48 and takes slightly more time 9.5 minutes.
The feature is still experimental, because:
+ we can make it much more robust if we have memory mapped/preprocessed datasets.
+ The speedup over sortish sampler is not that large at the moment.
# DistilBART
<!---It should be called distilling bart and pegasus, but I don't want to break the link in the paper.-->
This section describes all code and artifacts from our [Paper](http://arxiv.org/abs/2010.13002)
![DBART](https://huggingface.co/front/thumbnails/distilbart_large.png)
+ For the CNN/DailyMail dataset, (relatively longer, more extractive summaries), we found a simple technique that works, which we call "Shrink and Fine-tune", or SFT.
you just copy alternating layers from `facebook/bart-large-cnn` and fine-tune more on the cnn/dm data. `sshleifer/distill-pegasus-cnn-16-4`, `sshleifer/distilbart-cnn-12-6` and all other checkpoints under `sshleifer` that start with `distilbart-cnn` were trained this way.
+ For the XSUM dataset, training on pseudo-labels worked best for Pegasus (`sshleifer/distill-pegasus-16-4`), while training with KD worked best for `distilbart-xsum-12-6`
+ For `sshleifer/dbart-xsum-12-3`
+ We ran 100s experiments, and didn't want to document 100s of commands. If you want a command to replicate a figure from the paper that is not documented below, feel free to ask on the [forums](https://discuss.huggingface.co/t/seq2seq-distillation-methodology-questions/1270) and tag `@sshleifer`.
+ You can see the performance tradeoffs of model sizes [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=0).
and more granular timing results [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=1753259047&range=B2:I23).
### Evaluation
use [run_distributed_eval](./run_distributed_eval.py), with the following convenient alias
```bash
deval () {
proc=$1
m=$2
dd=$3
sd=$4
shift
shift
shift
shift
python -m torch.distributed.launch --nproc_per_node=$proc run_distributed_eval.py \
--model_name $m --save_dir $sd --data_dir $dd $@
}
```
On a 1 GPU system, here are four commands (that assume `xsum`, `cnn_dm` are downloaded, cmd-F for those links in this file).
`distilBART`:
```bash
deval 1 sshleifer/distilbart-xsum-12-3 xsum dbart_12_3_xsum_eval --fp16 # --help for more choices.
deval 1 sshleifer/distilbart-cnn_dm-12-6 cnn_dm dbart_12_6_cnn_eval --fp16
```
`distill-pegasus`:
```bash
deval 1 sshleifer/distill-pegasus-cnn-16-4 cnn_dm dpx_cnn_eval
deval 1 sshleifer/distill-pegasus-xsum-16-4 xsum dpx_xsum_eval
```
### Distillation
+ For all of the following commands, you can get roughly equivalent result and faster run times by passing `--num_beams=4`. That's not what we did for the paper.
+ Besides the KD section, you can also run commands with the built-in transformers trainer. See, for example, [builtin_trainer/train_distilbart_cnn.sh](./builtin_trainer/train_distilbart_cnn.sh).
+ Large performance deviations (> 5X slower or more than 0.5 Rouge-2 worse), should be reported.
+ Multi-gpu (controlled with `--gpus` should work, but might require more epochs).
#### Recommended Workflow
+ Get your dataset in the right format. (see 6 files above).
+ Find a teacher model [Pegasus](https://huggingface.co/models?search=pegasus) (slower, better ROUGE) or `facebook/bart-large-xsum`/`facebook/bart-large-cnn` (faster, slightly lower.).
Choose the checkpoint where the corresponding dataset is most similar (or identical to) your dataset.
+ Follow the sections in order below. You can stop after SFT if you are satisfied, or move on to pseudo-labeling if you want more performance.
+ student size: If you want a close to free 50% speedup, cut the decoder in half. If you want a larger speedup, cut it in 4.
+ If your SFT run starts at a validation ROUGE-2 that is more than 10 pts below the teacher's validation ROUGE-2, you have a bug. Switching to a more expensive technique will not help. Try setting a breakpoint and looking at generation and truncation defaults/hyper-parameters, and share your experience on the forums!
#### Initialization
We use [make_student.py](./make_student.py) to copy alternating layers from the teacher, and save the resulting model to disk
```bash
python make_student.py facebook/bart-large-xsum --save_path dbart_xsum_12_3 -e 12 -d 3
```
or for `pegasus-xsum`
```bash
python make_student.py google/pegasus-xsum --save_path dpx_xsum_16_4 --e 16 --d 4
```
we now have an initialized student saved to `dbart_xsum_12_3`, which we will use for the following commands.
+ Extension: To replicate more complicated initialize experiments in section 6.1, or try your own. Use the `create_student_by_copying_alternating_layers` function.
#### Pegasus
+ The following commands are written for BART and will require, at minimum, the following modifications
+ reduce batch size, and increase gradient accumulation steps so that the product `gpus * batch size * gradient_accumulation_steps = 256`. We used `--learning-rate` = 1e-4 * gradient accumulation steps.
+ don't use fp16
+ `--tokenizer_name google/pegasus-large`
### SFT (No Teacher Distillation)
You don't need `distillation.py`, you can just run:
```bash
python finetune.py \
--data_dir xsum \
--freeze_encoder --freeze_embeds \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 --fp16_opt_level=O1 \
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--model_name_or_path dbart_xsum_12_3 \
--train_batch_size=64 --eval_batch_size=64 \
--sortish_sampler \
--num_train_epochs=6 \
--warmup_steps 500 \
--output_dir distilbart_xsum_sft_12_3 --gpus 1
```
+ Note: The command that produced `sshleifer/distilbart-cnn-12-6` is at [train_distilbart_cnn.sh](./[train_distilbart_cnn.sh)
```bash
./train_distilbart_cnn.sh
```
<!--- runtime: 6H on NVIDIA RTX 24GB GPU -->
+ Tip: You can get the same simple distillation logic by using `distillation.py --no_teacher ` followed by identical arguments as the ones in `train_distilbart_cnn.sh`.
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
because you will have the same hyper-parameters logged in every run.
### Pseudo-Labeling
+ You don't need `distillation.py`.
+ Instructions to generate pseudo-labels and use pre-computed pseudo-labels can be found [here](./precomputed_pseudo_labels.md).
Simply run `finetune.py` with one of those pseudo-label datasets as `--data_dir` (`DATA`, below).
```bash
python finetune.py \
--teacher facebook/bart-large-xsum --data_dir DATA \
--freeze_encoder --freeze_embeds \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 --fp16_opt_level=O1 \
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--model_name_or_path dbart_xsum_12_3 \
--train_batch_size=32 --eval_batch_size=32 \
--sortish_sampler \
--num_train_epochs=5 \
--warmup_steps 500 \
--output_dir dbart_xsum_12_3_PL --gpus 1 --logger_name wandb
```
To combine datasets, as in Section 6.2, try something like:
```bash
curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/bart_xsum_pl.tgz | tar -xvz -C .
curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/pegasus_xsum.tgz | tar -xvz -C .
curl -S https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz | tar -xvz -C .
mkdir all_pl
cat bart_xsum_pl/train.source pegasus_xsum/train.source xsum/train.source > all_pl/train.source
cat bart_xsum_pl/train.target pegasus_xsum/train.target xsum/train.target > all_pl/train.target
cp xsum/val* all_pl
cp xsum/test* all_pl
```
then use `all_pl` as DATA in the command above.
#### Direct Knowledge Distillation (KD)
+ In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `SummarizationDistiller`.
+ This method was used for `sshleifer/distilbart-xsum-12-6`, `6-6`, and `9-6` checkpoints were produced.
+ You must use [`distillation.py`](./distillation.py). Note that this command initializes the student for you.
The command that produced `sshleifer/distilbart-xsum-12-6` is at [./train_distilbart_xsum.sh](train_distilbart_xsum.sh)
```bash
./train_distilbart_xsum.sh --logger_name wandb --gpus 1
```
+ Expected ROUGE-2 between 21.3 and 21.6, run time ~13H.
+ direct KD + Pegasus is VERY slow and works best with `--supervise_forward --normalize_hidden`.
<!--- runtime: 13H on V-100 16GB GPU. -->
### Citation
```bibtex
@misc{shleifer2020pretrained,
title={Pre-trained Summarization Distillation},
author={Sam Shleifer and Alexander M. Rush},
year={2020},
eprint={2010.13002},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@article{Wolf2019HuggingFacesTS,
title={HuggingFace's Transformers: State-of-the-art Natural Language Processing},
author={Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Joe Davison and Sam Shleifer and Patrick von Platen and Clara Ma and Yacine Jernite and Julien Plu and Canwen Xu and Teven Le Scao and Sylvain Gugger and Mariama Drame and Quentin Lhoest and Alexander M. Rush},
journal={ArXiv},
year={2019},
volume={abs/1910.03771}
}
```

View File

@@ -0,0 +1,203 @@
#!/usr/bin/env python
import argparse
import os
import sys
from unittest.mock import patch
import pytorch_lightning as pl
import timeout_decorator
import torch
from distillation import SummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from transformers import MarianMTModel
from transformers.file_utils import cached_path
from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
from utils import load_json
MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"
class TestMbartCc25Enro(TestCasePlus):
def setUp(self):
super().setUp()
data_cached = cached_path(
"https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
extract_compressed_file=True,
)
self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"
@slow
@require_torch_gpu
def test_model_download(self):
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
MarianMTModel.from_pretrained(MARIAN_MODEL)
# @timeout_decorator.timeout(1200)
@slow
@require_torch_gpu
def test_train_mbart_cc25_enro_script(self):
env_vars_to_replace = {
"$MAX_LEN": 64,
"$BS": 64,
"$GAS": 1,
"$ENRO_DIR": self.data_dir,
"facebook/mbart-large-cc25": MARIAN_MODEL,
# "val_check_interval=0.25": "val_check_interval=1.0",
"--learning_rate=3e-5": "--learning_rate 3e-4",
"--num_train_epochs 6": "--num_train_epochs 1",
}
# Clean up bash script
bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
for k, v in env_vars_to_replace.items():
bash_script = bash_script.replace(k, str(v))
output_dir = self.get_auto_remove_tmp_dir()
# bash_script = bash_script.replace("--fp16 ", "")
args = f"""
--output_dir {output_dir}
--tokenizer_name Helsinki-NLP/opus-mt-en-ro
--sortish_sampler
--do_predict
--gpus 1
--freeze_encoder
--n_train 40000
--n_val 500
--n_test 500
--fp16_opt_level O1
--num_sanity_val_steps 0
--eval_beams 2
""".split()
# XXX: args.gpus > 1 : handle multi_gpu in the future
testargs = ["finetune.py"] + bash_script.split() + args
with patch.object(sys, "argv", testargs):
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
model = main(args)
# Check metrics
metrics = load_json(model.metrics_save_path)
first_step_stats = metrics["val"][0]
last_step_stats = metrics["val"][-1]
self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
# model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)
# test learning requirements:
# 1. BLEU improves over the course of training by more than 2 pts
self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)
# 2. BLEU finishes above 17
self.assertGreater(last_step_stats["val_avg_bleu"], 17)
# 3. test BLEU and val BLEU within ~1.1 pt.
self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)
# check lightning ckpt can be loaded and has a reasonable statedict
contents = os.listdir(output_dir)
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
full_path = os.path.join(args.output_dir, ckpt_path)
ckpt = torch.load(full_path, map_location="cpu")
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
assert expected_key in ckpt["state_dict"]
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
# TODO: turn on args.do_predict when PL bug fixed.
if args.do_predict:
contents = {os.path.basename(p) for p in contents}
assert "test_generations.txt" in contents
assert "test_results.txt" in contents
# assert len(metrics["val"]) == desired_n_evals
assert len(metrics["test"]) == 1
class TestDistilMarianNoTeacher(TestCasePlus):
@timeout_decorator.timeout(600)
@slow
@require_torch_gpu
def test_opus_mt_distill_script(self):
data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
env_vars_to_replace = {
"--fp16_opt_level=O1": "",
"$MAX_LEN": 128,
"$BS": 16,
"$GAS": 1,
"$ENRO_DIR": data_dir,
"$m": "sshleifer/student_marian_en_ro_6_1",
"val_check_interval=0.25": "val_check_interval=1.0",
}
# Clean up bash script
bash_script = (
(self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
)
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
bash_script = bash_script.replace("--fp16 ", " ")
for k, v in env_vars_to_replace.items():
bash_script = bash_script.replace(k, str(v))
output_dir = self.get_auto_remove_tmp_dir()
bash_script = bash_script.replace("--fp16", "")
epochs = 6
testargs = (
["distillation.py"]
+ bash_script.split()
+ [
f"--output_dir={output_dir}",
"--gpus=1",
"--learning_rate=1e-3",
f"--num_train_epochs={epochs}",
"--warmup_steps=10",
"--val_check_interval=1.0",
"--do_predict",
]
)
with patch.object(sys, "argv", testargs):
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
# assert args.gpus == gpus THIS BREAKS for multi_gpu
model = distill_main(args)
# Check metrics
metrics = load_json(model.metrics_save_path)
first_step_stats = metrics["val"][0]
last_step_stats = metrics["val"][-1]
assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check
assert last_step_stats["val_avg_gen_time"] >= 0.01
assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing
assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved.
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
# check lightning ckpt can be loaded and has a reasonable statedict
contents = os.listdir(output_dir)
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
full_path = os.path.join(args.output_dir, ckpt_path)
ckpt = torch.load(full_path, map_location="cpu")
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
assert expected_key in ckpt["state_dict"]
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
# TODO: turn on args.do_predict when PL bug fixed.
if args.do_predict:
contents = {os.path.basename(p) for p in contents}
assert "test_generations.txt" in contents
assert "test_results.txt" in contents
# assert len(metrics["val"]) == desired_n_evals
assert len(metrics["test"]) == 1

View File

@@ -0,0 +1,44 @@
import tempfile
import unittest
from make_student import create_student_by_copying_alternating_layers
from transformers import AutoConfig
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, require_torch_non_multi_gpu_but_fix_me
TINY_BART = "sshleifer/bart-tiny-random"
TINY_T5 = "patrickvonplaten/t5-tiny-random"
@require_torch
class MakeStudentTester(unittest.TestCase):
@cached_property
def teacher_config(self):
return AutoConfig.from_pretrained(TINY_BART)
@require_torch_non_multi_gpu_but_fix_me
def test_valid_t5(self):
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1)
self.assertEqual(student.config.num_hidden_layers, 1)
@require_torch_non_multi_gpu_but_fix_me
def test_asymmetric_t5(self):
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None)
@require_torch_non_multi_gpu_but_fix_me
def test_same_decoder_small_encoder(self):
student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=None)
self.assertEqual(student.config.encoder_layers, 1)
self.assertEqual(student.config.decoder_layers, self.teacher_config.encoder_layers)
@require_torch_non_multi_gpu_but_fix_me
def test_small_enc_small_dec(self):
student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=1)
self.assertEqual(student.config.encoder_layers, 1)
self.assertEqual(student.config.decoder_layers, 1)
@require_torch_non_multi_gpu_but_fix_me
def test_raises_assert(self):
with self.assertRaises(AssertionError):
create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=None, d=None)

View File

@@ -0,0 +1,443 @@
import argparse
import logging
import os
import sys
import tempfile
from pathlib import Path
import pytest
import pytorch_lightning as pl
import torch
import lightning_base
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
from distillation import distill_main
from finetune import SummarizationModule, main
from parameterized import parameterized
from run_eval import generate_summaries_or_translations
from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
from utils import label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"max_tokens_per_batch": None,
"supervise_forward": True,
"normalize_hidden": True,
"label_smoothing": 0.2,
"eval_max_gen_length": None,
"eval_beams": 1,
"val_metric": "loss",
"save_top_k": 1,
"adafactor": True,
"early_stopping_patience": 2,
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
"task": "summarization",
"num_workers": 2,
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
"tgt_suffix": "",
"resume_from_checkpoint": None,
"sortish_sampler": True,
"student_decoder_layers": 1,
"val_check_interval": 1.0,
"output_dir": "",
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
"no_teacher": False,
"fp16_opt_level": "O1",
"gpus": 1 if CUDA_AVAILABLE else 0,
"n_tpu_cores": 0,
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"accumulate_grad_batches": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
"model_name_or_path": "sshleifer/bart-tiny-random",
"config_name": "",
"tokenizer_name": "facebook/bart-large",
"do_lower_case": False,
"learning_rate": 0.3,
"lr_scheduler": "linear",
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"max_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
"max_target_length": 12,
"val_max_target_length": 12,
"test_max_target_length": 12,
"fast_dev_run": False,
"no_cache": False,
"n_train": -1,
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"freeze_encoder": False,
"auto_scale_batch_size": False,
"overwrite_output_dir": False,
"student": None,
}
def _dump_articles(path: Path, articles: list):
content = "\n".join(articles)
Path(path).open("w").writelines(content)
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
T5_TINIER = "sshleifer/t5-tinier-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
FSMT_TINY = "stas/tiny-wmt19-en-de"
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
def make_test_data_dir(tmp_dir):
for split in ["train", "val", "test"]:
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
return tmp_dir
class TestSummarizationDistiller(TestCasePlus):
@classmethod
def setUpClass(cls):
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls
@slow
@require_torch_gpu
def test_hub_configs(self):
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
model_list = HfApi().model_list()
org = "sshleifer"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
failures = []
for m in model_ids:
if m in allowed_to_be_broken:
continue
try:
AutoConfig.from_pretrained(m)
except Exception:
failures.append(m)
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
def test_distill_no_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
self._test_distiller_cli(updates)
def test_distill_checkpointing_with_teacher(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
max_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
)
model = self._test_distiller_cli(updates, check_contents=False)
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
self.assertEqual(len(transformer_ckpts), 2)
examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
out_path = tempfile.mktemp() # XXX: not being cleaned up
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
self.assertTrue(Path(out_path).exists())
out_path_new = self.get_auto_remove_tmp_dir()
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
def test_loss_fn(self):
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
lm_labels = target_ids[:, 1:].clone() # why clone?
model_computed_loss = model(
input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=lm_labels, use_cache=False
).loss
logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
smoothed_loss, nll_loss = label_smoothed_nll_loss(
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
)
with self.assertRaises(AssertionError):
# TODO: understand why this breaks
self.assertEqual(nll_loss, model_computed_loss)
def test_distill_mbart(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
task="translation",
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
tokenizer_name=MBART_TINY,
teacher=MBART_TINY,
src_lang="en_XX",
tgt_lang="ro_RO",
)
model = self._test_distiller_cli(updates, check_contents=False)
assert model.model.config.model_type == "mbart"
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
all_files = list(Path(model.output_dir).glob("best_tfmr/*"))
assert len(all_files) > 2
self.assertEqual(len(transformer_ckpts), 2)
def test_distill_t5(self):
updates = dict(
student_encoder_layers=1,
student_decoder_layers=1,
alpha_hid=2.0,
teacher=T5_TINY,
model_name_or_path=T5_TINY,
tokenizer_name=T5_TINY,
)
self._test_distiller_cli(updates)
def test_distill_different_base_models(self):
updates = dict(
teacher=T5_TINY,
student=T5_TINIER,
model_name_or_path=T5_TINIER,
tokenizer_name=T5_TINIER,
)
self._test_distiller_cli(updates)
def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
label_smoothing=0.0,
early_stopping_patience=-1,
train_batch_size=1,
eval_batch_size=2,
max_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
output_dir = self.get_auto_remove_tmp_dir()
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
model = distill_main(argparse.Namespace(**args_d))
if not check_contents:
return model
contents = os.listdir(output_dir)
contents = {os.path.basename(p) for p in contents}
ckpt_files = [p for p in contents if p.endswith("ckpt")]
assert len(ckpt_files) > 0
self.assertIn("test_generations.txt", contents)
self.assertIn("test_results.txt", contents)
metrics = load_json(model.metrics_save_path)
last_step_stats = metrics["val"][-1]
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["test"]), 1)
return model
class TestTheRest(TestCasePlus):
@parameterized.expand(
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
)
def test_finetune(self, model):
args_d: dict = CHEAP_ARGS.copy()
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
output_dir = self.get_auto_remove_tmp_dir()
args_d.update(
data_dir=tmp_dir,
model_name_or_path=model,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
output_dir=output_dir,
do_predict=True,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
module = main(args)
input_embeds = module.model.get_input_embeddings()
assert not input_embeds.weight.requires_grad
if model == T5_TINY:
lm_head = module.model.lm_head
assert not lm_head.weight.requires_grad
assert (lm_head.weight == input_embeds.weight).all().item()
elif model == FSMT_TINY:
fsmt = module.model.model
embed_pos = fsmt.decoder.embed_positions
assert not embed_pos.weight.requires_grad
assert not fsmt.decoder.embed_tokens.weight.requires_grad
# check that embeds are not the same
assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
else:
bart = module.model.model
embed_pos = bart.decoder.embed_positions
assert not embed_pos.weight.requires_grad
assert not bart.shared.weight.requires_grad
# check that embeds are the same
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
assert bart.decoder.embed_tokens == bart.shared
example_batch = load_json(module.output_dir / "text_batch.json")
assert isinstance(example_batch, dict)
assert len(example_batch) >= 4
def test_finetune_extra_model_args(self):
args_d: dict = CHEAP_ARGS.copy()
task = "summarization"
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
args_d.update(
data_dir=tmp_dir,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
do_predict=False,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
# test models whose config includes the extra_model_args
model = BART_TINY
output_dir = self.get_auto_remove_tmp_dir()
args_d1 = args_d.copy()
args_d1.update(
model_name_or_path=model,
output_dir=output_dir,
)
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
args_d1[p] = 0.5
args = argparse.Namespace(**args_d1)
model = main(args)
for p in extra_model_params:
assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
# test models whose config doesn't include the extra_model_args
model = T5_TINY
output_dir = self.get_auto_remove_tmp_dir()
args_d2 = args_d.copy()
args_d2.update(
model_name_or_path=model,
output_dir=output_dir,
)
unsupported_param = "encoder_layerdrop"
args_d2[unsupported_param] = 0.5
args = argparse.Namespace(**args_d2)
with pytest.raises(Exception) as excinfo:
model = main(args)
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
def test_finetune_lr_schedulers(self):
args_d: dict = CHEAP_ARGS.copy()
task = "summarization"
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
model = BART_TINY
output_dir = self.get_auto_remove_tmp_dir()
args_d.update(
data_dir=tmp_dir,
model_name_or_path=model,
output_dir=output_dir,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
do_predict=False,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
# emulate finetune.py
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = {"--help": True}
# --help test
with pytest.raises(SystemExit) as excinfo:
with CaptureStdout() as cs:
args = parser.parse_args(args)
assert False, "--help is expected to sys.exit"
assert excinfo.type == SystemExit
expected = lightning_base.arg_to_scheduler_metavar
assert expected in cs.out, "--help is expected to list the supported schedulers"
# --lr_scheduler=non_existing_scheduler test
unsupported_param = "non_existing_scheduler"
args = {f"--lr_scheduler={unsupported_param}"}
with pytest.raises(SystemExit) as excinfo:
with CaptureStderr() as cs:
args = parser.parse_args(args)
assert False, "invalid argument is expected to sys.exit"
assert excinfo.type == SystemExit
expected = f"invalid choice: '{unsupported_param}'"
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
# --lr_scheduler=existing_scheduler test
supported_param = "cosine"
args_d1 = args_d.copy()
args_d1["lr_scheduler"] = supported_param
args = argparse.Namespace(**args_d1)
model = main(args)
assert (
getattr(model.hparams, "lr_scheduler") == supported_param
), f"lr_scheduler={supported_param} shouldn't fail"

View File

@@ -0,0 +1,164 @@
# as due to their complexity multi-gpu tests could impact other tests, and to aid debug we have those in a separate module.
import os
import sys
from pathlib import Path
import torch
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
from utils import load_json
CUDA_AVAILABLE = torch.cuda.is_available()
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
CHEAP_ARGS = {
"max_tokens_per_batch": None,
"supervise_forward": True,
"normalize_hidden": True,
"label_smoothing": 0.2,
"eval_max_gen_length": None,
"eval_beams": 1,
"val_metric": "loss",
"save_top_k": 1,
"adafactor": True,
"early_stopping_patience": 2,
"logger_name": "default",
"length_penalty": 0.5,
"cache_dir": "",
"task": "summarization",
"num_workers": 2,
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
"tgt_suffix": "",
"resume_from_checkpoint": None,
"sortish_sampler": True,
"student_decoder_layers": 1,
"val_check_interval": 1.0,
"output_dir": "",
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
"no_teacher": False,
"fp16_opt_level": "O1",
"gpus": 1 if CUDA_AVAILABLE else 0,
"n_tpu_cores": 0,
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"accumulate_grad_batches": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
"model_name_or_path": "sshleifer/bart-tiny-random",
"config_name": "",
"tokenizer_name": "facebook/bart-large",
"do_lower_case": False,
"learning_rate": 0.3,
"lr_scheduler": "linear",
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"max_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
"max_target_length": 12,
"val_max_target_length": 12,
"test_max_target_length": 12,
"fast_dev_run": False,
"no_cache": False,
"n_train": -1,
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"freeze_encoder": False,
"auto_scale_batch_size": False,
"overwrite_output_dir": False,
"student": None,
}
def _dump_articles(path: Path, articles: list):
content = "\n".join(articles)
Path(path).open("w").writelines(content)
def make_test_data_dir(tmp_dir):
for split in ["train", "val", "test"]:
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
return tmp_dir
class TestSummarizationDistillerMultiGPU(TestCasePlus):
@classmethod
def setUpClass(cls):
return cls
@require_torch_multi_gpu
def test_multi_gpu(self):
updates = dict(
no_teacher=True,
freeze_encoder=True,
gpus=2,
overwrite_output_dir=True,
sortish_sampler=True,
)
self._test_distiller_cli_fork(updates, check_contents=False)
def _test_distiller_cli_fork(self, updates, check_contents=True):
default_updates = dict(
label_smoothing=0.0,
early_stopping_patience=-1,
train_batch_size=1,
eval_batch_size=2,
max_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
output_dir = self.get_auto_remove_tmp_dir()
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
def convert(k, v):
if k in ["tgt_suffix", "server_ip", "server_port", "out", "n_tpu_cores"]:
return ""
if v is False or v is None:
return ""
if v is True: # or len(str(v))==0:
return f"--{k}"
return f"--{k}={v}"
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
cmd = [sys.executable, f"{self.test_file_dir}/distillation.py"] + cli_args
execute_subprocess_async(cmd, env=self.get_env())
contents = os.listdir(output_dir)
contents = {os.path.basename(p) for p in contents}
ckpt_files = [p for p in contents if p.endswith("ckpt")]
assert len(ckpt_files) > 0
self.assertIn("test_generations.txt", contents)
self.assertIn("test_results.txt", contents)
# get the following from the module, (we don't have access to `model` here)
metrics_save_path = os.path.join(output_dir, "metrics.json")
val_metric = "rouge2"
metrics = load_json(metrics_save_path)
# {'test': [{'test_avg_loss': 10.63731575012207, 'test_avg_rouge1': 0.0, 'test_avg_rouge2': 0.0, 'test_avg_rougeL': 0.0, 'test_avg_gen_time': 0.1822289228439331, 'test_avg_gen_len': 142.0, 'step_count': 1}]}
print(metrics)
last_step_stats = metrics["val"][-1]
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
self.assertEqual(len(metrics["test"]), 1)
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)

View File

@@ -0,0 +1,115 @@
import logging
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from utils import save_json
def count_trainable_parameters(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return params
logger = logging.getLogger(__name__)
class Seq2SeqLoggingCallback(pl.Callback):
def on_batch_end(self, trainer, pl_module):
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
pl_module.logger.log_metrics(lrs)
@rank_zero_only
def _write_logs(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
) -> None:
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
metrics = trainer.callback_metrics
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
# Log results
od = Path(pl_module.hparams.output_dir)
if type_path == "test":
results_file = od / "test_results.txt"
generations_file = od / "test_generations.txt"
else:
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
# If people want this it will be easy enough to add back.
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
results_file.parent.mkdir(exist_ok=True)
generations_file.parent.mkdir(exist_ok=True)
with open(results_file, "a+") as writer:
for key in sorted(metrics):
if key in ["log", "progress_bar", "preds"]:
continue
val = metrics[key]
if isinstance(val, torch.Tensor):
val = val.item()
msg = f"{key}: {val:.6f}\n"
writer.write(msg)
if not save_generations:
return
if "preds" in metrics:
content = "\n".join(metrics["preds"])
generations_file.open("w+").write(content)
@rank_zero_only
def on_train_start(self, trainer, pl_module):
try:
npars = pl_module.model.model.num_parameters()
except AttributeError:
npars = pl_module.model.num_parameters()
n_trainable_pars = count_trainable_parameters(pl_module)
# mp stands for million parameters
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
save_json(pl_module.metrics, pl_module.metrics_save_path)
return self._write_logs(trainer, pl_module, "test")
@rank_zero_only
def on_validation_end(self, trainer: pl.Trainer, pl_module):
save_json(pl_module.metrics, pl_module.metrics_save_path)
# Uncommenting this will save val generations
# return self._write_logs(trainer, pl_module, "valid")
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
"""Saves the best model by validation ROUGE2 score."""
if metric == "rouge2":
exp = "{val_avg_rouge2:.4f}-{step_count}"
elif metric == "bleu":
exp = "{val_avg_bleu:.4f}-{step_count}"
elif metric == "loss":
exp = "{val_avg_loss:.4f}-{step_count}"
else:
raise NotImplementedError(
f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
)
checkpoint_callback = ModelCheckpoint(
dirpath=output_dir,
filename=exp,
monitor=f"val_{metric}",
mode="min" if "loss" in metric else "max",
save_top_k=save_top_k,
)
return checkpoint_callback
def get_early_stopping_callback(metric, patience):
return EarlyStopping(
monitor=f"val_{metric}", # does this need avg?
mode="min" if "loss" in metric else "max",
patience=patience,
verbose=True,
)

View File

@@ -0,0 +1,74 @@
#!/usr/bin/env python
import os
from pathlib import Path
from typing import Dict, List
import fire
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.utils.logging import get_logger
logger = get_logger(__name__)
def remove_prefix(text: str, prefix: str):
if text.startswith(prefix):
return text[len(prefix) :]
return text # or whatever
def sanitize(sd):
return {remove_prefix(k, "model."): v for k, v in sd.items()}
def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]):
new_sd = {}
for k in state_dicts[0].keys():
tensors = [sd[k] for sd in state_dicts]
new_t = sum(tensors) / len(tensors)
assert isinstance(new_t, torch.Tensor)
new_sd[k] = new_t
return new_sd
def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None:
"""Cleanup a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict.
Silently allows extra pl keys (like teacher.) Puts all ckpt models into CPU RAM at once!
Args:
pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning or dir containing ckpt files.
If a directory is passed, all .ckpt files inside it will be averaged!
hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint
save_path (:obj:`str`): Directory to save the new model
"""
hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir)
if os.path.isfile(pl_ckpt_path):
ckpt_files = [pl_ckpt_path]
else:
assert os.path.isdir(pl_ckpt_path)
ckpt_files = list(Path(pl_ckpt_path).glob("*.ckpt"))
assert ckpt_files, f"could not find any ckpt files inside the {pl_ckpt_path} directory"
if len(ckpt_files) > 1:
logger.info(f"averaging the weights of {ckpt_files}")
state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files]
state_dict = average_state_dicts(state_dicts)
missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
assert not missing, f"missing keys: {missing}"
hf_model.save_pretrained(save_path)
try:
tok = AutoTokenizer.from_pretrained(hf_src_model_dir)
tok.save_pretrained(save_path)
except Exception:
pass
# dont copy tokenizer if cant
if __name__ == "__main__":
fire.Fire(convert_pl_to_hf)

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
# export MAX_LEN=128
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--fp16 \
--val_check_interval 0.25 \
--teacher Helsinki-NLP/opus-mt-en-ro \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--student_decoder_layers 3 --student_encoder_layers 6 \
--freeze_encoder --freeze_embeds \
--model_name_or_path IGNORED \
--alpha_hid=3. \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
--warmup_steps 500 --logger_name wandb \
--fp16_opt_level O1 --task translation --normalize_hidden --num_sanity_val_steps=0 \
"$@"

View File

@@ -0,0 +1,17 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 --no_teacher \
--val_check_interval 0.25 \
--data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--freeze_encoder --freeze_embeds \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name $m --model_name_or_path $m \
--warmup_steps 500 --sortish_sampler --logger_name wandb \
--gpus 1 --fp16_opt_level=O1 --task translation --num_sanity_val_steps=0 \
"$@"

View File

@@ -0,0 +1,310 @@
#!/usr/bin/env python
import argparse
import gc
import os
import sys
from pathlib import Path
from typing import List
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
from transformers.models.bart.modeling_bart import shift_tokens_right
from utils import calculate_bleu, check_output_dir, freeze_params, label_smoothed_nll_loss, use_task_specific_params
# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import generic_train # noqa
class SummarizationDistiller(SummarizationModule):
"""Supports T5, Bart, Pegasus and other models that inherit from Bart."""
loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
def __init__(self, hparams):
assert Path(hparams.data_dir).exists()
self.output_dir = Path(hparams.output_dir)
self.output_dir.mkdir(exist_ok=True)
save_dir = self.output_dir.joinpath("student")
hparams.model_name_or_path = str(save_dir) # Tell lightning we are training the student
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default
if hparams.student is not None:
student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student)
use_task_specific_params(student, hparams.task)
e_layer_ids, d_layer_ids = None, None
else:
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
)
if hparams.length_penalty != -1:
student.config.length_penalty = hparams.length_penalty
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
super().__init__(hparams, model=student, config=student.config)
assert (
student.config.model_type == teacher.config.model_type
), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}"
if student.config.model_type == "t5":
student_encoder_layers = len(student.get_encoder().block)
student_decoder_layers = len(student.get_decoder().block)
teacher_encoder_layers = len(teacher.get_encoder().block)
teacher_decoder_layers = len(teacher.get_decoder().block)
else:
student_encoder_layers = student.config.encoder_layers
student_decoder_layers = student.config.decoder_layers
teacher_encoder_layers = teacher.config.encoder_layers
teacher_decoder_layers = teacher.config.decoder_layers
self.different_base_models = not (hparams.student is None or hparams.teacher == hparams.student)
self.do_calc_hidden_loss = (not self.different_base_models) and hparams.alpha_hid > 0
self.different_encoder = self.different_base_models or (student_encoder_layers != teacher_encoder_layers)
# self.different_encoder determines whether we need to run the teacher encoder
self.teacher = teacher
freeze_params(self.teacher)
if not self.different_encoder: # To save RAM, delete teacher encoder and freeze student encoder.
try:
del self.teacher.model.encoder
except AttributeError: # T5
del self.teacher.encoder
if e_layer_ids is None:
e_layer_ids = list(range(student_encoder_layers))
if d_layer_ids is None:
d_layer_ids = list(range(student_decoder_layers))
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
if self.do_calc_hidden_loss: # Intermediate supervision: Decide which layers to supervise
if hparams.supervise_forward:
self.e_matches = get_layers_to_supervise(
n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers
)
self.d_matches = get_layers_to_supervise(
n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers
)
else: # student layer should emulate hidden states of the teacher layer it was copied from
self.e_matches = self.e_layer_ids
self.d_matches = self.d_layer_ids
else:
self.e_matches = None
self.d_matches = None
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.temperature = 2.0
self.alpha_mlm = hparams.alpha_mlm
self.alpha_ce = hparams.alpha_ce
self.alpha_hid = hparams.alpha_hid
gc.collect()
torch.cuda.empty_cache()
def calc_ce_loss(self, mask, s_logits, t_logits):
"""Copy pasted from distillbert (transformers/examples/distillation/)"""
# mask has False at padding_idx
sel_mask = mask[:, :, None].expand_as(s_logits)
vocab_size = s_logits.size(-1)
s_logits_slct = torch.masked_select(s_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask
t_logits_slct = torch.masked_select(t_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask
s_logits_slct = s_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask
t_logits_slct = t_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask
assert t_logits_slct.size() == s_logits_slct.size()
loss_ce = (
self.ce_loss_fct(
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
F.softmax(t_logits_slct / self.temperature, dim=-1),
)
* (self.temperature) ** 2
)
return loss_ce
@staticmethod
def add_model_specific_args(parser, root_dir):
SummarizationModule.add_model_specific_args(parser, root_dir)
add_distill_args(parser)
return parser
def _step(self, batch: dict) -> tuple:
"""Compute the loss for a batch"""
pad_token_id = self.tokenizer.pad_token_id
input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
if isinstance(self.model, T5ForConditionalGeneration):
decoder_input_ids = self.model._shift_right(labels)
else:
decoder_input_ids = shift_tokens_right(labels, pad_token_id)
# noinspection PyCallingNonCallable
student_outputs = self(
input_ids,
attention_mask=src_mask,
decoder_input_ids=decoder_input_ids,
output_hidden_states=self.do_calc_hidden_loss,
output_attentions=False,
use_cache=False,
)
lm_logits = student_outputs["logits"]
# Same cross entropy vs. label smoothing logic as finetune.py
assert lm_logits.shape[-1] == self.model.config.vocab_size
if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
else:
lprobs = F.log_softmax(lm_logits, dim=-1)
student_lm_loss, _ = label_smoothed_nll_loss(
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
)
def zero_tensor():
return torch.tensor(0.0).type_as(student_lm_loss)
teacher_enc_outputs = student_outputs[
"encoder_last_hidden_state"
] # use this unless self.different_base_models
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
if self.different_encoder: # compute encoder hidden state loss
all_teacher_encoder_outputs = self.teacher.get_encoder()(
input_ids,
attention_mask=src_mask,
output_hidden_states=self.do_calc_hidden_loss,
)
if self.different_base_models:
teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
elif self.do_calc_hidden_loss:
hid_loss_enc = self.calc_hidden_loss(
src_mask,
student_outputs["encoder_hidden_states"],
all_teacher_encoder_outputs["hidden_states"],
self.e_matches,
normalize_hidden=self.hparams.normalize_hidden,
)
teacher_outputs = self.teacher(
input_ids,
attention_mask=src_mask,
encoder_outputs=(teacher_enc_outputs,),
decoder_input_ids=decoder_input_ids,
output_hidden_states=self.do_calc_hidden_loss,
use_cache=False, # since we are not passing labels, never let this default to True
)
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"])
if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states
hid_loss_dec = self.calc_hidden_loss(
dec_mask,
student_outputs["decoder_hidden_states"],
teacher_outputs["decoder_hidden_states"],
self.d_matches,
normalize_hidden=self.hparams.normalize_hidden,
)
blended_loss = (
self.alpha_ce * loss_ce
+ self.alpha_mlm * student_lm_loss
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
)
return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
@staticmethod
def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
"""MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT."""
msg = "expected list or tuple for hidden_states, got tensor of shape: "
assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
mask = attention_mask.to(hidden_states[0])
valid_count = mask.sum() * hidden_states[0].size(-1)
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
if normalize_hidden:
student_states = F.layer_norm(student_states, student_states.shape[1:])
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
mse = F.mse_loss(student_states, teacher_states, reduction="none")
masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
return masked_mse
def add_distill_args(parser):
# NOTE: if --student argument was specified and the teacher and student base models
# are different, the models still have to have the same tokenizer, specified by
# --tokenizer_name. So, for example, you can distill from t5_large to t5_small but not
# from bart to t5. This s because if the tokenizers are different, the output space
# for the two models is also different and their logits are not comparable.
parser.add_argument("--teacher", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float)
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
parser.add_argument("--student", type=str, required=False)
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
parser.add_argument("--no_teacher", action="store_true", default=False)
parser.add_argument("--length_penalty", type=float, default=-1)
parser.add_argument("--supervise_forward", action="store_true", default=False)
parser.add_argument("--normalize_hidden", action="store_true", default=False)
class TranslationDistiller(SummarizationDistiller):
"""Supports T5, mBART, Marian, other models that inherit from Bart."""
mode = "translation"
metric_names = ["bleu"]
default_val_metric = "bleu"
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
assert hparams.src_lang is not None
assert hparams.tgt_lang is not None
self.dataset_kwargs["src_lang"] = hparams.src_lang
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu(preds, target)
@staticmethod
def add_model_specific_args(parser, root_dir):
TranslationModule.add_model_specific_args(parser, root_dir)
add_distill_args(parser)
return parser
def create_module(args):
if args.no_teacher:
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
else: # DISTILL WITH TEACHER
module_cls = TranslationDistiller if "translation" in args.task else SummarizationDistiller
args.setup_cls: str = module_cls.__name__
print(f"using module {args.setup_cls}")
model = module_cls(args)
return model
def distill_main(args):
Path(args.output_dir).mkdir(exist_ok=True)
check_output_dir(args, expected_items=3)
model = create_module(args)
return ft_main(args, model=model)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
distill_main(args)

View File

@@ -0,0 +1,17 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export WANDB_PROJECT=dmar
export MAX_LEN=128
export m=sshleifer/student_marian_en_ro_6_1
python finetune.py \
--learning_rate=3e-4 \
--do_train \
--fp16 \
--data_dir wmt_en_ro \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--freeze_encoder --freeze_embeds \
--train_batch_size=48 --eval_batch_size=64 \
--tokenizer_name $m --model_name_or_path $m --num_train_epochs=1 \
--warmup_steps 500 --logger_name wandb --gpus 1 \
--fp16_opt_level=O1 --task translation \
"$@"

View File

@@ -0,0 +1,442 @@
#!/usr/bin/env python
import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.models.bart.modeling_bart import shift_tokens_right
from utils import (
ROUGE_KEYS,
LegacySeq2SeqDataset,
Seq2SeqDataset,
assert_all_frozen,
calculate_bleu,
calculate_rouge,
check_output_dir,
flatten_list,
freeze_embeds,
freeze_params,
get_git_info,
label_smoothed_nll_loss,
lmap,
pickle_save,
save_git_info,
save_json,
use_task_specific_params,
)
# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
logger = logging.getLogger(__name__)
class SummarizationModule(BaseTransformer):
mode = "summarization"
loss_names = ["loss"]
metric_names = ROUGE_KEYS
default_val_metric = "rouge2"
def __init__(self, hparams, **kwargs):
if hparams.sortish_sampler and hparams.gpus > 1:
hparams.replace_sampler_ddp = False
elif hparams.max_tokens_per_batch is not None:
if hparams.gpus > 1:
raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
if hparams.sortish_sampler:
raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
use_task_specific_params(self.model, "summarization")
save_git_info(self.hparams.output_dir)
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
pickle_save(self.hparams, self.hparams_save_path)
self.step_count = 0
self.metrics = defaultdict(list)
self.model_type = self.config.model_type
self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
self.dataset_kwargs: dict = dict(
data_dir=self.hparams.data_dir,
max_source_length=self.hparams.max_source_length,
prefix=self.model.config.prefix or "",
)
n_observations_per_split = {
"train": self.hparams.n_train,
"val": self.hparams.n_val,
"test": self.hparams.n_test,
}
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
self.target_lens = {
"train": self.hparams.max_target_length,
"val": self.hparams.val_max_target_length,
"test": self.hparams.test_max_target_length,
}
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
if self.hparams.freeze_embeds:
freeze_embeds(self.model)
if self.hparams.freeze_encoder:
freeze_params(self.model.get_encoder())
assert_all_frozen(self.model.get_encoder())
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.decoder_start_token_id = None # default to config
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
self.model.config.decoder_start_token_id = self.decoder_start_token_id
self.dataset_class = (
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
)
self.already_saved_batch = False
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
if self.hparams.eval_max_gen_length is not None:
self.eval_max_length = self.hparams.eval_max_gen_length
else:
self.eval_max_length = self.model.config.max_length
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
"""A debugging utility"""
readable_batch = {
k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
}
save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")
self.already_saved_batch = True
return readable_batch
def forward(self, input_ids, **kwargs):
return self.model(input_ids, **kwargs)
def ids_to_clean_text(self, generated_ids: List[int]):
gen_text = self.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
return lmap(str.strip, gen_text)
def _step(self, batch: dict) -> Tuple:
pad_token_id = self.tokenizer.pad_token_id
src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
tgt_ids = batch["labels"]
if isinstance(self.model, T5ForConditionalGeneration):
decoder_input_ids = self.model._shift_right(tgt_ids)
else:
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero
batch["decoder_input_ids"] = decoder_input_ids
self.save_readable_batch(batch)
outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
lm_logits = outputs["logits"]
if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
assert lm_logits.shape[-1] == self.vocab_size
loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
)
return (loss,)
@property
def pad(self) -> int:
return self.tokenizer.pad_token_id
def training_step(self, batch, batch_idx) -> Dict:
loss_tensors = self._step(batch)
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
# tokens per batch
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
logs["bs"] = batch["input_ids"].shape[0]
logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
# TODO(SS): make a wandb summary metric for this
return {"loss": loss_tensors[0], "log": logs}
def validation_step(self, batch, batch_idx) -> Dict:
return self._generative_step(batch)
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
generative_metrics = {
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
}
metric_val = (
generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
)
metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
generative_metrics.update({k: v.item() for k, v in losses.items()})
losses.update(generative_metrics)
all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
all_metrics["step_count"] = self.step_count
self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
preds = flatten_list([x["preds"] for x in outputs])
return {
"log": all_metrics,
"preds": preds,
f"{prefix}_loss": loss,
f"{prefix}_{self.val_metric}": metric_tensor,
}
def calc_generative_metrics(self, preds, target) -> Dict:
return calculate_rouge(preds, target)
def _generative_step(self, batch: dict) -> dict:
t0 = time.time()
# parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
generated_ids = self.model.generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
use_cache=True,
decoder_start_token_id=self.decoder_start_token_id,
num_beams=self.eval_beams,
max_length=self.eval_max_length,
)
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
preds: List[str] = self.ids_to_clean_text(generated_ids)
target: List[str] = self.ids_to_clean_text(batch["labels"])
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
rouge: Dict = self.calc_generative_metrics(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
return base_metrics
def test_step(self, batch, batch_idx):
return self._generative_step(batch)
def test_epoch_end(self, outputs):
return self.validation_epoch_end(outputs, prefix="test")
def get_dataset(self, type_path) -> Seq2SeqDataset:
n_obs = self.n_obs[type_path]
max_target_length = self.target_lens[type_path]
dataset = self.dataset_class(
self.tokenizer,
type_path=type_path,
n_obs=n_obs,
max_target_length=max_target_length,
**self.dataset_kwargs,
)
return dataset
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
dataset = self.get_dataset(type_path)
if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
return DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate_fn,
shuffle=False,
num_workers=self.num_workers,
sampler=sampler,
)
elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
batch_sampler = dataset.make_dynamic_sampler(
self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
)
return DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=dataset.collate_fn,
# shuffle=False,
num_workers=self.num_workers,
# batch_size=None,
)
else:
return DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate_fn,
shuffle=shuffle,
num_workers=self.num_workers,
sampler=None,
)
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
return dataloader
def val_dataloader(self) -> DataLoader:
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
def test_dataloader(self) -> DataLoader:
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
@staticmethod
def add_model_specific_args(parser, root_dir):
BaseTransformer.add_model_specific_args(parser, root_dir)
add_generic_args(parser, root_dir)
parser.add_argument(
"--max_source_length",
default=1024,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--max_target_length",
default=56,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--val_max_target_length",
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--test_max_target_length",
default=142,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--freeze_encoder", action="store_true")
parser.add_argument("--freeze_embeds", action="store_true")
parser.add_argument("--sortish_sampler", action="store_true", default=False)
parser.add_argument("--overwrite_output_dir", action="store_true", default=False)
parser.add_argument("--max_tokens_per_batch", type=int, default=None)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
)
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
parser.add_argument("--eval_beams", type=int, default=None, required=False)
parser.add_argument(
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
)
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
parser.add_argument(
"--early_stopping_patience",
type=int,
default=-1,
required=False,
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
)
return parser
class TranslationModule(SummarizationModule):
mode = "translation"
loss_names = ["loss"]
metric_names = ["bleu"]
default_val_metric = "bleu"
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.dataset_kwargs["src_lang"] = hparams.src_lang
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu(preds, target)
def main(args, model=None) -> SummarizationModule:
Path(args.output_dir).mkdir(exist_ok=True)
check_output_dir(args, expected_items=3)
if model is None:
if "summarization" in args.task:
model: SummarizationModule = SummarizationModule(args)
else:
model: SummarizationModule = TranslationModule(args)
dataset = Path(args.data_dir).name
if (
args.logger_name == "default"
or args.fast_dev_run
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
logger = True # don't pollute wandb logs unnecessarily
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
project = os.environ.get("WANDB_PROJECT", dataset)
logger = WandbLogger(name=model.output_dir.name, project=project)
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
if args.early_stopping_patience >= 0:
es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
else:
es_callback = False
lower_is_better = args.val_metric == "loss"
trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_checkpoint_callback(
args.output_dir, model.val_metric, args.save_top_k, lower_is_better
),
early_stopping_callback=es_callback,
logger=logger,
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
if not args.do_predict:
return model
model.hparams.test_checkpoint = ""
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
if checkpoints:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1]
trainer.logger.log_hyperparams(model.hparams)
# test() without a model tests using the best checkpoint automatically
trainer.test()
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,11 @@
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
# run ./finetune.sh --help to see all the possible options
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--n_val 1000 \
--val_check_interval 0.1 \
"$@"

View File

@@ -0,0 +1,32 @@
# Script for verifying that run_bart_sum can be invoked from its directory
# Get tiny dataset with cnn_dm format (4 examples for train, val, test)
wget https://cdn-datasets.huggingface.co/summarization/cnn_tiny.tgz
tar -xzvf cnn_tiny.tgz
rm cnn_tiny.tgz
export OUTPUT_DIR_NAME=bart_utest_output
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py and testing_utils.py
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--data_dir=cnn_tiny/ \
--model_name_or_path=sshleifer/bart-tiny-random \
--learning_rate=3e-5 \
--train_batch_size=2 \
--eval_batch_size=2 \
--output_dir=$OUTPUT_DIR \
--num_train_epochs=1 \
--gpus=0 \
--do_train "$@"
rm -rf cnn_tiny
rm -rf $OUTPUT_DIR

View File

@@ -0,0 +1,14 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
# From appendix C of paper https://arxiv.org/abs/1912.08777
# Set --gradient_accumulation_steps so that effective batch size is 256 (2*128, 4*64, 8*32, 16*16)
python finetune.py \
--learning_rate=1e-4 \
--do_train \
--do_predict \
--n_val 1000 \
--val_check_interval 0.25 \
--max_source_length 512 --max_target_length 56 \
--freeze_embeds --label_smoothing 0.1 --adafactor --task summarization_xsum \
"$@"

View File

@@ -0,0 +1,14 @@
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--data_dir=$CNN_DIR \
--learning_rate=3e-5 \
--train_batch_size=$BS \
--eval_batch_size=$BS \
--output_dir=$OUTPUT_DIR \
--max_source_length=512 \
--max_target_length=56 \
--val_check_interval=0.1 --n_val=200 \
--do_train --do_predict \
"$@"

View File

@@ -0,0 +1,391 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Any, Dict
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from transformers import (
AdamW,
AutoConfig,
AutoModel,
AutoModelForPreTraining,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
PretrainedConfig,
PreTrainedTokenizer,
)
from transformers.optimization import (
Adafactor,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
)
from transformers.utils.versions import require_version_examples
logger = logging.getLogger(__name__)
require_version_examples("pytorch_lightning>=1.0.4")
MODEL_MODES = {
"base": AutoModel,
"sequence-classification": AutoModelForSequenceClassification,
"question-answering": AutoModelForQuestionAnswering,
"pretraining": AutoModelForPreTraining,
"token-classification": AutoModelForTokenClassification,
"language-modeling": AutoModelWithLMHead,
"summarization": AutoModelForSeq2SeqLM,
"translation": AutoModelForSeq2SeqLM,
}
# update this and the import above to support new schedulers from transformers.optimization
arg_to_scheduler = {
"linear": get_linear_schedule_with_warmup,
"cosine": get_cosine_schedule_with_warmup,
"cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
"polynomial": get_polynomial_decay_schedule_with_warmup,
# '': get_constant_schedule, # not supported for now
# '': get_constant_schedule_with_warmup, # not supported for now
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
class BaseTransformer(pl.LightningModule):
def __init__(
self,
hparams: argparse.Namespace,
num_labels=None,
mode="base",
config=None,
tokenizer=None,
model=None,
**config_kwargs
):
"""Initialize a model, tokenizer and config."""
super().__init__()
# TODO: move to self.save_hyperparameters()
# self.save_hyperparameters()
# can also expand arguments into trainer signature for easier reading
self.save_hyperparameters(hparams)
self.step_count = 0
self.output_dir = Path(self.hparams.output_dir)
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
if config is None:
self.config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=cache_dir,
**config_kwargs,
)
else:
self.config: PretrainedConfig = config
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
if getattr(self.hparams, p, None):
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
setattr(self.config, p, getattr(self.hparams, p))
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
cache_dir=cache_dir,
)
else:
self.tokenizer: PreTrainedTokenizer = tokenizer
self.model_type = MODEL_MODES[mode]
if model is None:
self.model = self.model_type.from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
config=self.config,
cache_dir=cache_dir,
)
else:
self.model = model
def load_hf_checkpoint(self, *args, **kwargs):
self.model = self.model_type.from_pretrained(*args, **kwargs)
def get_lr_scheduler(self):
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
scheduler = get_schedule_func(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps()
)
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return scheduler
def configure_optimizers(self):
"""Prepare optimizer and schedule (linear warmup and decay)"""
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": self.hparams.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
if self.hparams.adafactor:
optimizer = Adafactor(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
)
else:
optimizer = AdamW(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
)
self.opt = optimizer
scheduler = self.get_lr_scheduler()
return [optimizer], [scheduler]
def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)
def test_epoch_end(self, outputs):
return self.validation_end(outputs)
def total_steps(self) -> int:
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, mode):
if mode == "test":
self.dataset_size = len(self.test_dataloader().dataset)
else:
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
self.dataset_size = len(self.train_dataloader().dataset)
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
raise NotImplementedError("You must implement this for your task")
def train_dataloader(self):
return self.train_loader
def val_dataloader(self):
return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)
def test_dataloader(self):
return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)
def _feature_file(self, mode):
return os.path.join(
self.hparams.data_dir,
"cached_{}_{}_{}".format(
mode,
list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
str(self.hparams.max_seq_length),
),
)
@pl.utilities.rank_zero_only
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
save_path = self.output_dir.joinpath("best_tfmr")
self.model.config.save_step = self.step_count
self.model.save_pretrained(save_path)
self.tokenizer.save_pretrained(save_path)
@staticmethod
def add_model_specific_args(parser, root_dir):
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)
parser.add_argument(
"--tokenizer_name",
default=None,
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
)
parser.add_argument(
"--encoder_layerdrop",
type=float,
help="Encoder layer dropout probability (Optional). Goes into model.config",
)
parser.add_argument(
"--decoder_layerdrop",
type=float,
help="Decoder layer dropout probability (Optional). Goes into model.config",
)
parser.add_argument(
"--dropout",
type=float,
help="Dropout probability (Optional). Goes into model.config",
)
parser.add_argument(
"--attention_dropout",
type=float,
help="Attention dropout probability (Optional). Goes into model.config",
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument(
"--lr_scheduler",
default="linear",
choices=arg_to_scheduler_choices,
metavar=arg_to_scheduler_metavar,
type=str,
help="Learning rate scheduler",
)
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
parser.add_argument("--train_batch_size", default=32, type=int)
parser.add_argument("--eval_batch_size", default=32, type=int)
parser.add_argument("--adafactor", action="store_true")
class LoggingCallback(pl.Callback):
def on_batch_end(self, trainer, pl_module):
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
pl_module.logger.log_metrics(lrs)
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
rank_zero_info("***** Validation results *****")
metrics = trainer.callback_metrics
# Log results
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
rank_zero_info("***** Test results *****")
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))
def add_generic_args(parser, root_dir) -> None:
# To allow all pl args uncomment the following line
# parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O2",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
parser.add_argument(
"--gradient_accumulation_steps",
dest="accumulate_grad_batches",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
)
def generic_train(
model: BaseTransformer,
args: argparse.Namespace,
early_stopping_callback=None,
logger=True, # can pass WandbLogger() here
extra_callbacks=[],
checkpoint_callback=None,
logging_callback=None,
**extra_train_kwargs
):
pl.seed_everything(args.seed)
# init model
odir = Path(model.hparams.output_dir)
odir.mkdir(exist_ok=True)
# add custom checkpoints
if checkpoint_callback is None:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
)
if early_stopping_callback:
extra_callbacks.append(early_stopping_callback)
if logging_callback is None:
logging_callback = LoggingCallback()
train_params = {}
# TODO: remove with PyTorch 1.6 since pl uses native amp
if args.fp16:
train_params["precision"] = 16
train_params["amp_level"] = args.fp16_opt_level
if args.gpus > 1:
train_params["distributed_backend"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,
callbacks=[logging_callback] + extra_callbacks,
logger=logger,
checkpoint_callback=checkpoint_callback,
**train_params,
)
if args.do_train:
trainer.fit(model)
return trainer

View File

@@ -0,0 +1,173 @@
import warnings
from pathlib import Path
from typing import List, Tuple, Union
import fire
from torch import nn
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList, layers_to_copy: List[int]) -> None:
layers_to_copy = nn.ModuleList([src_layers[i] for i in layers_to_copy])
assert len(dest_layers) == len(layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}"
dest_layers.load_state_dict(layers_to_copy.state_dict())
LAYERS_TO_COPY = {
# maps num layers in teacher -> num_layers in student -> which teacher layers to copy.
# 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
12: {
1: [0], # This says that if the teacher has 12 layers and the student has 1, copy layer 0 of the teacher
2: [0, 6],
3: [0, 6, 11],
4: [0, 4, 8, 11],
6: [0, 2, 4, 7, 9, 11],
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
12: list(range(12)),
},
16: { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 15],
3: [0, 8, 15],
4: [0, 5, 10, 15],
6: [0, 3, 6, 9, 12, 15],
8: [0, 2, 4, 6, 8, 10, 12, 15],
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15],
16: list(range(16)),
},
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
}
LAYERS_TO_SUPERVISE = {
# maps num layers in student -> which teacher layers to copy.
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
}
def pick_layers_to_copy(n_student, n_teacher):
try:
val = LAYERS_TO_COPY[n_teacher][n_student]
return val
except KeyError:
if n_student != n_teacher:
warnings.warn(
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
)
return list(range(n_student))
def get_layers_to_supervise(n_student, n_teacher) -> List[int]:
"""Used or the --supervise_forward kwarg"""
if n_student > n_teacher:
raise ValueError(f"Cannot perform intermediate supervision for student {n_student} > teacher {n_teacher}")
elif n_teacher == n_student:
return list(range(n_teacher))
elif n_student == 1:
return [n_teacher - 1]
else:
return LAYERS_TO_SUPERVISE[n_teacher][n_student]
def create_student_by_copying_alternating_layers(
teacher: Union[str, PreTrainedModel],
save_path: Union[str, Path] = "student",
e: Union[int, None] = None,
d: Union[int, None] = None,
copy_first_teacher_layers=False,
e_layers_to_copy=None,
d_layers_to_copy=None,
**extra_config_kwargs
) -> Tuple[PreTrainedModel, List[int], List[int]]:
"""Make a student by copying alternating layers from a teacher, save it to save_path.
Args:
teacher: str or PreTrainedModel if str, this will call AutoModelForSeq2SeqLM.from_pretrained(teacher) before
copying layers
save_path: where to save the student, defaults to student directory.
e: how many Encoder layers should the student have, default is fully copy of teacher
d: how many Decoder layers should the student have, default is fully copy of teacher
copy_first_teacher_layers: [bool] dont copy alternating layers, just the first e/d.
**extra_config_kwargs: extra kwargs to pass to the student, by default the teacher config is used.
Returns:
student: new, smaller model. (Also saves it to save_path)
e_layers_to_copy: list of which teacher encoder layers were used
d_layers_to_copy: list of which teacher decoder layers were used
"""
_msg = "encoder_layers and decoder_layers cannot be both None-- you would just have an identical teacher."
assert (e is not None) or (d is not None), _msg
if isinstance(teacher, str):
AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path) # purely for convenience
teacher = AutoModelForSeq2SeqLM.from_pretrained(teacher).eval()
else:
assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}"
init_kwargs = teacher.config.to_diff_dict()
try:
teacher_e, teacher_d = teacher.config.encoder_layers, teacher.config.decoder_layers
if e is None:
e = teacher_e
if d is None:
d = teacher_d
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
except AttributeError: # T5
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
if e is None:
e = teacher_e
if d is None:
d = teacher_d
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
init_kwargs.update(extra_config_kwargs)
# Copy weights
student_cfg = teacher.config_class(**init_kwargs)
student = AutoModelForSeq2SeqLM.from_config(student_cfg)
# Start by copying the full teacher state dict this will copy the first N teacher layers to the student.
info = student.load_state_dict(teacher.state_dict(), strict=False)
assert info.missing_keys == [], info.missing_keys # every student key should have a teacher keys.
if copy_first_teacher_layers: # Our copying is done. We just log and save
e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d))
logger.info(
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
)
student.save_pretrained(save_path)
return student, e_layers_to_copy, d_layers_to_copy
# Decide which layers of the teacher to copy. Not exactly alternating -- we try to keep first and last layer.
if e_layers_to_copy is None:
e_layers_to_copy: List[int] = pick_layers_to_copy(e, teacher_e)
if d_layers_to_copy is None:
d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
try:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
logger.info(
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
)
student.config.init_metadata = dict(
teacher_type=teacher.config.model_type,
copied_encoder_layers=e_layers_to_copy,
copied_decoder_layers=d_layers_to_copy,
)
student.save_pretrained(save_path)
# Save information about copying for easier reproducibility
return student, e_layers_to_copy, d_layers_to_copy
if __name__ == "__main__":
fire.Fire(create_student_by_copying_alternating_layers)

View File

@@ -0,0 +1,43 @@
### Saved Pseudo-Labels
These are the generations of various large models on various large **training** sets. All in all they took about 200 GPU hours to produce.
### Available Pseudo-labels
| Dataset | Model | Link | Rouge Scores | Notes
|---------|-----------------------------|----------------------------------------------------------------------------------------|--------------------|-------------------------------------------------------------------------------------------------------------
| XSUM | `facebook/bart-large-xsum` | [download](https://cdn-datasets.huggingface.co/pseudo/xsum/bart_xsum_pl.tgz) | 49.8/28.0/42.5 |
| XSUM | `google/pegasus-xsum` | [download](https://cdn-datasets.huggingface.co/pseudo/xsum/pegasus_xsum.tgz) | 53.3/32.7/46.5 |
| XSUM | `facebook/bart-large-xsum` | [download](https://cdn-datasets.huggingface.co/pseudo/xsum/xsum_pl2_bart.tgz) | | Bart pseudolabels filtered to those with Rouge2 > 10.0 w GT.
| CNN/DM | `sshleifer/pegasus-cnn-ft-v2` | [download](https://cdn-datasets.huggingface.co/pseudo/cnn_dm/pegasus_cnn_cnn_pls.tgz) | 47.316/26.65/44.56 | do not worry about the fact that train.source is one line shorter.
| CNN/DM | `facebook/bart-large-cnn` | [download](https://cdn-datasets.huggingface.co/pseudo/cnn_dm/cnn_bart_pl.tgz) | | 5K (2%) are missing, there should be 282173
| CNN/DM | `google/pegasus-xsum` | [download](https://cdn-datasets.huggingface.co/pseudo/cnn_dm/pegasus_xsum_on_cnn.tgz) | 21.5/6.76/25 | extra labels for xsum distillation Used max_source_length=512, (and all other pegasus-xsum configuration).
| EN-RO | `Helsinki-NLP/opus-mt-en-ro` | [download](https://cdn-datasets.huggingface.co/pseudo/wmt_en_ro/opus_mt_en_ro.tgz) | |
| EN-RO | `facebook/mbart-large-en-ro` | [download](https://cdn-datasets.huggingface.co/pseudo/wmt_en_ro/mbart_large_en_ro.tgz) | |
(EN_RO = WMT 2016 English-Romanian).
Example Download Command:
```bash
curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/bart_xsum_pl.tgz | tar -xvz -C .
```
### Generating New Pseudolabels
Here is the command I used to generate the pseudolabels in the second row of the table, after downloading XSUM from [here](https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz).
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
--model_name google/pegasus-xsum \
--save_dir pegasus_xsum \
--data_dir xsum \
--bs 8 --sync_timeout 60000 \
--max_source_length 512 \
--type_path train
```
+ These commands takes a while to run. For example, `pegasus_cnn_cnn_pls.tgz` took 8 hours on 8 GPUs.
+ Pegasus does not work in fp16 :(, Bart, mBART and Marian do.
+ Even if you have 1 GPU, `run_distributed_eval.py` is 10-20% faster than `run_eval.py` because it uses `SortishSampler` to minimize padding computation.
### Contributions
Feel free to contribute your own pseudolabels via PR. Add a row to this table with a new google drive link (or other command line downloadable link).

View File

@@ -0,0 +1,20 @@
tensorboard
scikit-learn
psutil
sacrebleu
rouge-score
tensorflow_datasets
pytorch-lightning==1.0.4
matplotlib
git-python==1.0.3
faiss-cpu
streamlit
elasticsearch
nltk
pandas
datasets >= 1.1.3
fire
pytest
conllu
sentencepiece != 0.1.92
protobuf

View File

@@ -0,0 +1,163 @@
#!/usr/bin/env python
import argparse
import datetime
import json
import time
import warnings
from logging import getLogger
from pathlib import Path
from typing import Dict, List
import torch
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import calculate_bleu, calculate_rouge, chunks, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
logger = getLogger(__name__)
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def generate_summaries_or_translations(
examples: List[str],
out_file: str,
model_name: str,
batch_size: int = 8,
device: str = DEFAULT_DEVICE,
fp16=False,
task="summarization",
prefix=None,
**generate_kwargs,
) -> Dict:
"""Save model.generate results to <out_file>, and return how long it took."""
fout = Path(out_file).open("w", encoding="utf-8")
model_name = str(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
if fp16:
model = model.half()
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
start_time = time.time()
# update config with task specific params
use_task_specific_params(model, task)
if prefix is None:
prefix = prefix or getattr(model.config, "prefix", "") or ""
for examples_chunk in tqdm(list(chunks(examples, batch_size))):
examples_chunk = [prefix + text for text in examples_chunk]
batch = tokenizer(examples_chunk, return_tensors="pt", truncation=True, padding="longest").to(device)
summaries = model.generate(
input_ids=batch.input_ids,
attention_mask=batch.attention_mask,
**generate_kwargs,
)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()
fout.close()
runtime = int(time.time() - start_time) # seconds
n_obs = len(examples)
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
def datetime_now():
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def run_generate(verbose=True):
"""
Takes input text, generates output, and then using reference calculates the BLEU scores.
The results are saved to a file and returned to the caller, and printed out unless ``verbose=False`` is passed.
Args:
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): print results to stdout
Returns:
a tuple: ``(scores, params}``
- ``scores``: a dict of scores data ``{'bleu': 39.6501, 'n_obs': 2000, 'runtime': 186, 'seconds_per_sample': 0.093}``
- ``params``: a dict of custom params, e.g. ``{'num_beams': 5, 'length_penalty': 0.8}``
"""
parser = argparse.ArgumentParser()
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
parser.add_argument("save_path", type=str, help="where to save summaries")
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument(
"--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
)
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument(
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
)
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--dump-args", action="store_true", help="print the custom hparams with the results")
parser.add_argument(
"--info",
nargs="?",
type=str,
const=datetime_now(),
help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
)
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()
parsed_args = parse_numeric_n_bool_cl_kwargs(rest)
if parsed_args and verbose:
print(f"parsed the following generate kwargs: {parsed_args}")
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
if args.n_obs > 0:
examples = examples[: args.n_obs]
Path(args.save_path).parent.mkdir(exist_ok=True)
if args.reference_path is None and Path(args.score_path).exists():
warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
runtime_metrics = generate_summaries_or_translations(
examples,
args.save_path,
args.model_name,
batch_size=args.bs,
device=args.device,
fp16=args.fp16,
task=args.task,
prefix=args.prefix,
**parsed_args,
)
if args.reference_path is None:
return {}
# Compute scores
score_fn = calculate_bleu if "translation" in args.task else calculate_rouge
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
scores: dict = score_fn(output_lns, reference_lns)
scores.update(runtime_metrics)
if args.dump_args:
scores.update(parsed_args)
if args.info:
scores["info"] = args.info
if verbose:
print(scores)
if args.score_path is not None:
json.dump(scores, open(args.score_path, "w"))
return scores
if __name__ == "__main__":
# Usage for MT:
# python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@
run_generate(verbose=True)

View File

@@ -0,0 +1,22 @@
import re
from filelock import FileLock
try:
import nltk
NLTK_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
NLTK_AVAILABLE = False
if NLTK_AVAILABLE:
with FileLock(".lock") as lock:
nltk.download("punkt", quiet=True)
def add_newline_to_end_of_each_sentence(x: str) -> str:
"""This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
re.sub("<n>", "", x) # remove pegasus newline char
assert NLTK_AVAILABLE, "nltk must be installed to separate newlines between sentences. (pip install nltk)"
return "\n".join(nltk.sent_tokenize(x))

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export BS=32
export GAS=1
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--val_check_interval 0.25 \
--n_val 500 \
--num_train_epochs 2 \
--freeze_encoder --freeze_embeds --data_dir cnn_dm \
--max_target_length 142 --val_max_target_length=142 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--model_name_or_path sshleifer/student_cnn_12_6 \
--tokenizer_name facebook/bart-large \
--warmup_steps 500 \
--output_dir distilbart-cnn-12-6 \
"$@"

View File

@@ -0,0 +1,21 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
python distillation.py \
--teacher facebook/bart-large-xsum --data_dir xsum \
--tokenizer_name facebook/bart-large-xsum \
--student_decoder_layers 6 --student_encoder_layers 12 \
--freeze_encoder --freeze_embeds \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 --fp16_opt_level=O1 \
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--model_name_or_path IGNORED \
--alpha_hid=3. \
--train_batch_size=16 --eval_batch_size=16 --gradient_accumulation_steps=2 \
--sortish_sampler \
--num_train_epochs=6 \
--warmup_steps 500 \
--output_dir distilbart_xsum_12_6 \
"$@"

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--do_train \
--val_check_interval=0.25 \
--adam_eps 1e-06 \
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
--data_dir $ENRO_DIR \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--train_batch_size=$BS --eval_batch_size=$BS \
--task translation \
--warmup_steps 500 \
--freeze_embeds \
--model_name_or_path=facebook/mbart-large-cc25 \
"$@"

View File

@@ -0,0 +1,645 @@
import itertools
import json
import linecache
import math
import os
import pickle
import socket
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union
import git
import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.file_utils import cached_property
from transformers.models.bart.modeling_bart import shift_tokens_right
try:
from fairseq.data.data_utils import batch_by_size
FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
FAIRSEQ_AVAILABLE = False
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
"""From fairseq"""
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
nll_loss = nll_loss.sum() # mean()? Scared to break other math.
smooth_loss = smooth_loss.sum()
eps_i = epsilon / lprobs.size(-1)
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
"""Uses sacrebleu's corpus_bleu implementation."""
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
def non_pad_len(tokens: np.ndarray) -> int:
return np.count_nonzero(tokens != tokenizer.pad_token_id)
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
pred_str = lmap(str.strip, pred_str)
label_str = lmap(str.strip, label_str)
return pred_str, label_str
def summarization_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
rouge: Dict = calculate_rouge(pred_str, label_str)
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
rouge.update({"gen_len": summ_len})
return rouge
def translation_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
bleu: Dict = calculate_bleu(pred_str, label_str)
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
bleu.update({"gen_len": gen_len})
return bleu
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
return compute_metrics_fn
def trim_batch(
input_ids,
pad_token_id,
attention_mask=None,
):
"""Remove columns that are populated exclusively by pad_token_id"""
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
if attention_mask is None:
return input_ids[:, keep_column_mask]
else:
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class AbstractSeq2SeqDataset(Dataset):
def __init__(
self,
tokenizer,
data_dir,
max_source_length,
max_target_length,
type_path="train",
n_obs=None,
prefix="",
**dataset_kwargs
):
super().__init__()
self.src_file = Path(data_dir).joinpath(type_path + ".source")
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
self.len_file = Path(data_dir).joinpath(type_path + ".len")
if os.path.exists(self.len_file):
self.src_lens = pickle_load(self.len_file)
self.used_char_len = False
else:
self.src_lens = self.get_char_lens(self.src_file)
self.used_char_len = True
self.max_source_length = max_source_length
self.max_target_length = max_target_length
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
self.tokenizer = tokenizer
self.prefix = prefix if prefix is not None else ""
if n_obs is not None:
self.src_lens = self.src_lens[:n_obs]
self.pad_token_id = self.tokenizer.pad_token_id
self.dataset_kwargs = dataset_kwargs
dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})
def __len__(self):
return len(self.src_lens)
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
@cached_property
def tgt_lens(self):
"""Length in characters of target documents"""
return self.get_char_lens(self.tgt_file)
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
if distributed:
return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
else:
return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))
def num_tokens_in_example(i):
return min(self.src_lens[i], self.max_target_length)
# call fairseq cython function
batch_sampler: List[List[int]] = batch_by_size(
sorted_indices,
num_tokens_fn=num_tokens_in_example,
max_tokens=max_tokens_per_batch,
required_batch_size_multiple=64,
)
shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
# move the largest batch to the front to OOM quickly (uses an approximation for padding)
approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
largest_batch_idx = np.argmax(approximate_toks_per_batch)
shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
shuffled_batches[largest_batch_idx],
shuffled_batches[0],
)
return shuffled_batches
def __getitem__(self, item):
raise NotImplementedError("You must implement this")
def collate_fn(self, batch):
raise NotImplementedError("You must implement this")
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
"""Call tokenizer on src and tgt_lines"""
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)
source_ids = source_inputs["input_ids"].squeeze()
target_ids = target_inputs["input_ids"].squeeze()
src_mask = source_inputs["attention_mask"].squeeze()
return {
"input_ids": source_ids,
"attention_mask": src_mask,
"labels": target_ids,
}
def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
"""Only used by LegacyDataset"""
return tokenizer(
[line],
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
return_tensors=return_tensors,
**self.dataset_kwargs,
)
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["labels"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
"labels": y,
}
return batch
class Seq2SeqDataset(AbstractSeq2SeqDataset):
"""A dataset that calls prepare_seq2seq_batch."""
def __getitem__(self, index) -> Dict[str, str]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
"""Call prepare_seq2seq_batch."""
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
max_length=self.max_source_length,
max_target_length=self.max_target_length,
return_tensors="pt",
**self.dataset_kwargs,
).data
batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
return batch_encoding
class Seq2SeqDataCollator:
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
assert (
self.pad_token_id is not None
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
self.data_args = data_args
self.tpu_num_cores = tpu_num_cores
self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
if data_args.src_lang is not None:
self.dataset_kwargs["src_lang"] = data_args.src_lang
if data_args.tgt_lang is not None:
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
def __call__(self, batch) -> Dict[str, torch.Tensor]:
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
batch = self._encode(batch)
input_ids, attention_mask, labels = (
batch["input_ids"],
batch["attention_mask"],
batch["labels"],
)
else:
input_ids = torch.stack([x["input_ids"] for x in batch])
attention_mask = torch.stack([x["attention_mask"] for x in batch])
labels = torch.stack([x["labels"] for x in batch])
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
if isinstance(self.tokenizer, T5Tokenizer):
decoder_input_ids = self._shift_right_t5(labels)
else:
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"labels": labels,
}
return batch
def _shift_right_t5(self, input_ids):
# shift inputs to the right
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = self.pad_token_id
return shifted_input_ids
def _encode(self, batch) -> Dict[str, torch.Tensor]:
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_target_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
**self.dataset_kwargs,
)
return batch_encoding.data
class SortishSampler(Sampler):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
def __init__(self, data, batch_size, shuffle=True):
self.data, self.bs, self.shuffle = data, batch_size, shuffle
def __len__(self) -> int:
return len(self.data)
def __iter__(self):
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
if not shuffle:
return np.argsort(np.array(data) * -1)
def key_fn(i):
return data[i]
idxs = np.random.permutation(len(data))
sz = bs * 50
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
sz = bs
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
sort_idx = np.concatenate((ck_idx[0], sort_idx))
return sort_idx
class DistributedSortishSampler(Sampler):
"""Copied from torch DistributedSampler"""
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
if add_extra_examples:
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
else:
self.total_size = len(dataset)
self.num_samples = len(self.available_indices)
self.batch_size = batch_size
self.add_extra_examples = add_extra_examples
self.shuffle = shuffle
def __iter__(self) -> Iterable:
g = torch.Generator()
g.manual_seed(self.epoch)
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
indices = [self.available_indices[i] for i in sortish_indices]
assert len(indices) == self.num_samples
return iter(indices)
@cached_property
def available_indices(self) -> np.array:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
available_indices = indices[self.rank : self.total_size : self.num_replicas]
return available_indices
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
logger = getLogger(__name__)
def use_task_specific_params(model, task):
"""Update config with summarization specific params."""
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
pars = task_specific_params.get(task, {})
logger.info(f"using task specific params for {task}: {pars}")
model.config.update(pars)
def pickle_load(path):
"""pickle.load(path)"""
with open(path, "rb") as f:
return pickle.load(f)
def pickle_save(obj, path):
"""pickle.dump(obj, path)"""
with open(path, "wb") as f:
return pickle.dump(obj, f)
def flatten_list(summary_ids: List[List]):
return [x for x in itertools.chain.from_iterable(summary_ids)]
def save_git_info(folder_path: str) -> None:
"""Save git information to output_dir/git_log.json"""
repo_infos = get_git_info()
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
def save_json(content, path, indent=4, **json_dump_kwargs):
with open(path, "w") as f:
json.dump(content, f, indent=indent, **json_dump_kwargs)
def load_json(path):
with open(path) as f:
return json.load(f)
def get_git_info():
try:
repo = git.Repo(search_parent_directories=True)
repo_infos = {
"repo_id": str(repo),
"repo_sha": str(repo.head.object.hexsha),
"repo_branch": str(repo.active_branch),
"hostname": str(socket.gethostname()),
}
return repo_infos
except TypeError:
return {
"repo_id": None,
"repo_sha": None,
"repo_branch": None,
"hostname": None,
}
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
def extract_rouge_mid_statistics(dct):
new_dict = {}
for k1, v1 in dct.items():
mid = v1.mid
new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
return new_dict
def calculate_rouge(
pred_lns: List[str],
tgt_lns: List[str],
use_stemmer=True,
rouge_keys=ROUGE_KEYS,
return_precision_and_recall=False,
bootstrap_aggregation=True,
newline_sep=True,
) -> Dict:
"""Calculate rouge using rouge_scorer package.
Args:
pred_lns: list of summaries generated by model
tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
use_stemmer: Bool indicating whether Porter stemmer should be used to
strip word suffixes to improve matching.
rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
return_precision_and_recall: (False) whether to also return precision and recall.
bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]``
newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL
on multi sentence summaries (CNN/DM dataset).
Returns:
Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
"""
scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
aggregator = scoring.BootstrapAggregator()
for pred, tgt in zip(tgt_lns, pred_lns):
# rougeLsum expects "\n" separated sentences within a summary
if newline_sep:
pred = add_newline_to_end_of_each_sentence(pred)
tgt = add_newline_to_end_of_each_sentence(tgt)
scores = scorer.score(pred, tgt)
aggregator.add_scores(scores)
if bootstrap_aggregation:
result = aggregator.aggregate()
if return_precision_and_recall:
return extract_rouge_mid_statistics(result) # here we return dict
else:
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
else:
return aggregator._scores # here we return defaultdict(list)
# Utilities for freezing parameters and checking whether they are frozen
def freeze_params(model: nn.Module):
"""Set requires_grad=False for each of model.parameters()"""
for par in model.parameters():
par.requires_grad = False
def freeze_embeds(model):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
model_type = model.config.model_type
if model_type == "t5":
freeze_params(model.shared)
for d in [model.encoder, model.decoder]:
freeze_params(d.embed_tokens)
elif model_type == "fsmt":
for d in [model.model.encoder, model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
else:
freeze_params(model.model.shared)
for d in [model.model.encoder, model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
def grad_status(model: nn.Module) -> Iterable:
return (par.requires_grad for par in model.parameters())
def any_requires_grad(model: nn.Module) -> bool:
return any(grad_status(model))
def assert_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
n_require_grad = sum(lmap(int, model_grads))
npars = len(model_grads)
assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
def assert_not_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
npars = len(model_grads)
assert any(model_grads), f"none of {npars} weights require grad"
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
"""
Parse an argv list of unspecified command line args to a dict.
Assumes all values are either numeric or boolean in the form of true/false.
"""
result = {}
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
num_pairs = len(unparsed_args) // 2
for pair_num in range(num_pairs):
i = 2 * pair_num
assert unparsed_args[i].startswith("--")
if unparsed_args[i + 1].lower() == "true":
value = True
elif unparsed_args[i + 1].lower() == "false":
value = False
else:
try:
value = int(unparsed_args[i + 1])
except ValueError:
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
result[unparsed_args[i][2:]] = value
return result
def write_txt_file(ordered_tgt, path):
f = Path(path).open("w")
for ln in ordered_tgt:
f.write(ln + "\n")
f.flush()
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
def check_output_dir(args, expected_items=0):
"""
Checks whether to bail out if output_dir already exists and has more than expected_items in it
`args`: needs to have the following attributes of `args`:
- output_dir
- do_train
- overwrite_output_dir
`expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM)
"""
if (
os.path.exists(args.output_dir)
and len(os.listdir(args.output_dir)) > expected_items
and args.do_train
and not args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({args.output_dir}) already exists and "
f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
"Use --overwrite_output_dir to overcome."
)

View File

@@ -0,0 +1,645 @@
import itertools
import json
import linecache
import math
import os
import pickle
import socket
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Tuple, Union
import git
import numpy as np
import torch
import torch.distributed as dist
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.file_utils import cached_property
from transformers.models.bart.modeling_bart import shift_tokens_right
try:
from fairseq.data.data_utils import batch_by_size
FAIRSEQ_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
FAIRSEQ_AVAILABLE = False
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
"""From fairseq"""
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
nll_loss = nll_loss.sum() # mean()? Scared to break other math.
smooth_loss = smooth_loss.sum()
eps_i = epsilon / lprobs.size(-1)
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
"""Uses sacrebleu's corpus_bleu implementation."""
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
def non_pad_len(tokens: np.ndarray) -> int:
return np.count_nonzero(tokens != tokenizer.pad_token_id)
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
pred_str = lmap(str.strip, pred_str)
label_str = lmap(str.strip, label_str)
return pred_str, label_str
def summarization_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
rouge: Dict = calculate_rouge(pred_str, label_str)
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
rouge.update({"gen_len": summ_len})
return rouge
def translation_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
bleu: Dict = calculate_bleu(pred_str, label_str)
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
bleu.update({"gen_len": gen_len})
return bleu
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
return compute_metrics_fn
def trim_batch(
input_ids,
pad_token_id,
attention_mask=None,
):
"""Remove columns that are populated exclusively by pad_token_id"""
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
if attention_mask is None:
return input_ids[:, keep_column_mask]
else:
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class AbstractSeq2SeqDataset(Dataset):
def __init__(
self,
tokenizer,
data_dir,
max_source_length,
max_target_length,
type_path="train",
n_obs=None,
prefix="",
**dataset_kwargs
):
super().__init__()
self.src_file = Path(data_dir).joinpath(type_path + ".source")
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
self.len_file = Path(data_dir).joinpath(type_path + ".len")
if os.path.exists(self.len_file):
self.src_lens = pickle_load(self.len_file)
self.used_char_len = False
else:
self.src_lens = self.get_char_lens(self.src_file)
self.used_char_len = True
self.max_source_length = max_source_length
self.max_target_length = max_target_length
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
self.tokenizer = tokenizer
self.prefix = prefix if prefix is not None else ""
if n_obs is not None:
self.src_lens = self.src_lens[:n_obs]
self.pad_token_id = self.tokenizer.pad_token_id
self.dataset_kwargs = dataset_kwargs
dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})
def __len__(self):
return len(self.src_lens)
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
@cached_property
def tgt_lens(self):
"""Length in characters of target documents"""
return self.get_char_lens(self.tgt_file)
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
if distributed:
return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
else:
return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))
def num_tokens_in_example(i):
return min(self.src_lens[i], self.max_target_length)
# call fairseq cython function
batch_sampler: List[List[int]] = batch_by_size(
sorted_indices,
num_tokens_fn=num_tokens_in_example,
max_tokens=max_tokens_per_batch,
required_batch_size_multiple=64,
)
shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
# move the largest batch to the front to OOM quickly (uses an approximation for padding)
approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
largest_batch_idx = np.argmax(approximate_toks_per_batch)
shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
shuffled_batches[largest_batch_idx],
shuffled_batches[0],
)
return shuffled_batches
def __getitem__(self, item):
raise NotImplementedError("You must implement this")
def collate_fn(self, batch):
raise NotImplementedError("You must implement this")
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
"""Call tokenizer on src and tgt_lines"""
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)
source_ids = source_inputs["input_ids"].squeeze()
target_ids = target_inputs["input_ids"].squeeze()
src_mask = source_inputs["attention_mask"].squeeze()
return {
"input_ids": source_ids,
"attention_mask": src_mask,
"labels": target_ids,
}
def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
"""Only used by LegacyDataset"""
return tokenizer(
[line],
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
return_tensors=return_tensors,
**self.dataset_kwargs,
)
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["labels"] for x in batch])
pad_token_id = self.pad_token_id
y = trim_batch(target_ids, pad_token_id)
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
"labels": y,
}
return batch
class Seq2SeqDataset(AbstractSeq2SeqDataset):
"""A dataset that calls prepare_seq2seq_batch."""
def __getitem__(self, index) -> Dict[str, str]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
"""Call prepare_seq2seq_batch."""
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
max_length=self.max_source_length,
max_target_length=self.max_target_length,
return_tensors="pt",
**self.dataset_kwargs,
).data
batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
return batch_encoding
class Seq2SeqDataCollator:
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
assert (
self.pad_token_id is not None
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
self.data_args = data_args
self.tpu_num_cores = tpu_num_cores
self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
if data_args.src_lang is not None:
self.dataset_kwargs["src_lang"] = data_args.src_lang
if data_args.tgt_lang is not None:
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
def __call__(self, batch) -> Dict[str, torch.Tensor]:
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
batch = self._encode(batch)
input_ids, attention_mask, labels = (
batch["input_ids"],
batch["attention_mask"],
batch["labels"],
)
else:
input_ids = torch.stack([x["input_ids"] for x in batch])
attention_mask = torch.stack([x["attention_mask"] for x in batch])
labels = torch.stack([x["labels"] for x in batch])
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
if isinstance(self.tokenizer, T5Tokenizer):
decoder_input_ids = self._shift_right_t5(labels)
else:
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"labels": labels,
}
return batch
def _shift_right_t5(self, input_ids):
# shift inputs to the right
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = self.pad_token_id
return shifted_input_ids
def _encode(self, batch) -> Dict[str, torch.Tensor]:
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_target_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
**self.dataset_kwargs,
)
return batch_encoding.data
class SortishSampler(Sampler):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
def __init__(self, data, batch_size, shuffle=True):
self.data, self.bs, self.shuffle = data, batch_size, shuffle
def __len__(self) -> int:
return len(self.data)
def __iter__(self):
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
if not shuffle:
return np.argsort(np.array(data) * -1)
def key_fn(i):
return data[i]
idxs = np.random.permutation(len(data))
sz = bs * 50
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
sz = bs
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
sort_idx = np.concatenate((ck_idx[0], sort_idx))
return sort_idx
class DistributedSortishSampler(Sampler):
"""Copied from torch DistributedSampler"""
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
if add_extra_examples:
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
else:
self.total_size = len(dataset)
self.num_samples = len(self.available_indices)
self.batch_size = batch_size
self.add_extra_examples = add_extra_examples
self.shuffle = shuffle
def __iter__(self) -> Iterable:
g = torch.Generator()
g.manual_seed(self.epoch)
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
indices = [self.available_indices[i] for i in sortish_indices]
assert len(indices) == self.num_samples
return iter(indices)
@cached_property
def available_indices(self) -> np.array:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
available_indices = indices[self.rank : self.total_size : self.num_replicas]
return available_indices
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
logger = getLogger(__name__)
def use_task_specific_params(model, task):
"""Update config with summarization specific params."""
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
pars = task_specific_params.get(task, {})
logger.info(f"using task specific params for {task}: {pars}")
model.config.update(pars)
def pickle_load(path):
"""pickle.load(path)"""
with open(path, "rb") as f:
return pickle.load(f)
def pickle_save(obj, path):
"""pickle.dump(obj, path)"""
with open(path, "wb") as f:
return pickle.dump(obj, f)
def flatten_list(summary_ids: List[List]):
return [x for x in itertools.chain.from_iterable(summary_ids)]
def save_git_info(folder_path: str) -> None:
"""Save git information to output_dir/git_log.json"""
repo_infos = get_git_info()
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
def save_json(content, path, indent=4, **json_dump_kwargs):
with open(path, "w") as f:
json.dump(content, f, indent=indent, **json_dump_kwargs)
def load_json(path):
with open(path) as f:
return json.load(f)
def get_git_info():
try:
repo = git.Repo(search_parent_directories=True)
repo_infos = {
"repo_id": str(repo),
"repo_sha": str(repo.head.object.hexsha),
"repo_branch": str(repo.active_branch),
"hostname": str(socket.gethostname()),
}
return repo_infos
except TypeError:
return {
"repo_id": None,
"repo_sha": None,
"repo_branch": None,
"hostname": None,
}
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
def extract_rouge_mid_statistics(dct):
new_dict = {}
for k1, v1 in dct.items():
mid = v1.mid
new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
return new_dict
def calculate_rouge(
pred_lns: List[str],
tgt_lns: List[str],
use_stemmer=True,
rouge_keys=ROUGE_KEYS,
return_precision_and_recall=False,
bootstrap_aggregation=True,
newline_sep=True,
) -> Dict:
"""Calculate rouge using rouge_scorer package.
Args:
pred_lns: list of summaries generated by model
tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
use_stemmer: Bool indicating whether Porter stemmer should be used to
strip word suffixes to improve matching.
rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
return_precision_and_recall: (False) whether to also return precision and recall.
bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]``
newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL
on multi sentence summaries (CNN/DM dataset).
Returns:
Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
"""
scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
aggregator = scoring.BootstrapAggregator()
for pred, tgt in zip(tgt_lns, pred_lns):
# rougeLsum expects "\n" separated sentences within a summary
if newline_sep:
pred = add_newline_to_end_of_each_sentence(pred)
tgt = add_newline_to_end_of_each_sentence(tgt)
scores = scorer.score(pred, tgt)
aggregator.add_scores(scores)
if bootstrap_aggregation:
result = aggregator.aggregate()
if return_precision_and_recall:
return extract_rouge_mid_statistics(result) # here we return dict
else:
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
else:
return aggregator._scores # here we return defaultdict(list)
# Utilities for freezing parameters and checking whether they are frozen
def freeze_params(model: nn.Module):
"""Set requires_grad=False for each of model.parameters()"""
for par in model.parameters():
par.requires_grad = False
def freeze_embeds(model):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
model_type = model.config.model_type
if model_type == "t5":
freeze_params(model.shared)
for d in [model.encoder, model.decoder]:
freeze_params(d.embed_tokens)
elif model_type == "fsmt":
for d in [model.model.encoder, model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
else:
freeze_params(model.model.shared)
for d in [model.model.encoder, model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
def grad_status(model: nn.Module) -> Iterable:
return (par.requires_grad for par in model.parameters())
def any_requires_grad(model: nn.Module) -> bool:
return any(grad_status(model))
def assert_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
n_require_grad = sum(lmap(int, model_grads))
npars = len(model_grads)
assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
def assert_not_all_frozen(model):
model_grads: List[bool] = list(grad_status(model))
npars = len(model_grads)
assert any(model_grads), f"none of {npars} weights require grad"
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
"""
Parse an argv list of unspecified command line args to a dict.
Assumes all values are either numeric or boolean in the form of true/false.
"""
result = {}
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
num_pairs = len(unparsed_args) // 2
for pair_num in range(num_pairs):
i = 2 * pair_num
assert unparsed_args[i].startswith("--")
if unparsed_args[i + 1].lower() == "true":
value = True
elif unparsed_args[i + 1].lower() == "false":
value = False
else:
try:
value = int(unparsed_args[i + 1])
except ValueError:
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
result[unparsed_args[i][2:]] = value
return result
def write_txt_file(ordered_tgt, path):
f = Path(path).open("w")
for ln in ordered_tgt:
f.write(ln + "\n")
f.flush()
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
def check_output_dir(args, expected_items=0):
"""
Checks whether to bail out if output_dir already exists and has more than expected_items in it
`args`: needs to have the following attributes of `args`:
- output_dir
- do_train
- overwrite_output_dir
`expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM)
"""
if (
os.path.exists(args.output_dir)
and len(os.listdir(args.output_dir)) > expected_items
and args.do_train
and not args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({args.output_dir}) already exists and "
f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
"Use --overwrite_output_dir to overcome."
)