Reorganize examples (#9010)
* Reorganize example folder * Continue reorganization * Change requirements for tests * Final cleanup * Finish regroup with tests all passing * Copyright * Requirements and readme * Make a full link for the documentation * Address review comments * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Add symlink * Reorg again * Apply suggestions from code review Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com> * Adapt title * Update to new strucutre * Remove test * Update READMEs Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
@@ -1,3 +1,19 @@
|
||||
<!---
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
## Sequence to Sequence Training and Evaluation
|
||||
|
||||
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
|
||||
@@ -112,101 +128,6 @@ Datasets: `LegacySeq2SeqDataset` will be used for all tokenizers without a `prep
|
||||
Future work/help wanted: A new dataset to support multilingual tasks.
|
||||
|
||||
|
||||
### Finetuning Scripts
|
||||
All finetuning bash scripts call finetune.py (or distillation.py) with reasonable command line arguments. They usually require extra command line arguments to work.
|
||||
|
||||
To see all the possible command line options, run:
|
||||
|
||||
```bash
|
||||
./finetune.py --help
|
||||
```
|
||||
|
||||
### Finetuning Training Params
|
||||
|
||||
To override the pretrained model's training params, you can pass them to `./finetune.sh`:
|
||||
|
||||
```bash
|
||||
./finetune.sh \
|
||||
[...]
|
||||
--encoder_layerdrop 0.1 \
|
||||
--decoder_layerdrop 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
```
|
||||
|
||||
### Summarization Finetuning
|
||||
Run/modify `finetune.sh`
|
||||
|
||||
The following command should work on a 16GB GPU:
|
||||
```bash
|
||||
./finetune.sh \
|
||||
--data_dir $XSUM_DIR \
|
||||
--train_batch_size=1 \
|
||||
--eval_batch_size=1 \
|
||||
--output_dir=xsum_results \
|
||||
--num_train_epochs 6 \
|
||||
--model_name_or_path facebook/bart-large
|
||||
```
|
||||
|
||||
There is a starter finetuning script for pegasus at `finetune_pegasus_xsum.sh`.
|
||||
|
||||
### Translation Finetuning
|
||||
|
||||
First, follow the wmt_en_ro download instructions.
|
||||
Then you can finetune mbart_cc25 on english-romanian with the following command.
|
||||
**Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it.
|
||||
|
||||
Best performing command:
|
||||
```bash
|
||||
# optionally
|
||||
export ENRO_DIR='wmt_en_ro' # Download instructions above
|
||||
# export WANDB_PROJECT="MT" # optional
|
||||
export MAX_LEN=128
|
||||
export BS=4
|
||||
./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --label_smoothing 0.1 --fp16_opt_level=O1 --logger_name wandb --sortish_sampler
|
||||
```
|
||||
This should take < 6h/epoch on a 16GB v100 and achieve test BLEU above 26
|
||||
To get results in line with fairseq, you need to do some postprocessing. (see `romanian_postprocessing.md`)
|
||||
|
||||
MultiGPU command
|
||||
(using 8 GPUS as an example)
|
||||
```bash
|
||||
export ENRO_DIR='wmt_en_ro' # Download instructions above
|
||||
# export WANDB_PROJECT="MT" # optional
|
||||
export MAX_LEN=128
|
||||
export BS=4
|
||||
./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --gpus 8 --logger_name wandb
|
||||
```
|
||||
### Finetuning Outputs
|
||||
As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine).
|
||||
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
|
||||
|
||||
```bash
|
||||
output_dir
|
||||
├── best_tfmr # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
|
||||
│ ├── config.json
|
||||
│ ├── merges.txt
|
||||
│ ├── pytorch_model.bin
|
||||
│ ├── special_tokens_map.json
|
||||
│ ├── tokenizer_config.json
|
||||
│ └── vocab.json
|
||||
├── git_log.json # repo, branch, and commit hash
|
||||
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score. (it will be called BLEU for MT)
|
||||
├── metrics.json # new validation metrics will continually be appended to this
|
||||
├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
|
||||
│ ├── config.json
|
||||
│ └── pytorch_model.bin
|
||||
├── test_generations.txt
|
||||
# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
|
||||
├── test_results.txt # a convenience file with the test set metrics. This data is also in metrics.json['test']
|
||||
├── hparams.pkl # the command line args passed after some light preprocessing. Should be saved fairly quickly.
|
||||
```
|
||||
After training, you can recover the best checkpoint by running
|
||||
```python
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
|
||||
```
|
||||
|
||||
### Fine-tuning using Seq2SeqTrainer
|
||||
To use `Seq2SeqTrainer` for fine-tuning you should use the `finetune_trainer.py` script. It subclasses `Trainer` to extend it for seq2seq training. Except the `Trainer`-related `TrainingArguments`, it shares the same argument names as that of `finetune.py` file. One notable difference is that calculating generative metrics (BLEU, ROUGE) is optional and is controlled using the `--predict_with_generate` argument.
|
||||
|
||||
@@ -242,190 +163,6 @@ The following command fine-tunes `sshleifer/student_marian_en_ro_6_3` on TPU V3-
|
||||
./builtin_trainer/train_distil_marian_enro_tpu.sh
|
||||
```
|
||||
|
||||
# DistilBART
|
||||
<!---It should be called distilling bart and pegasus, but I don't want to break the link in the paper.-->
|
||||
This section describes all code and artifacts from our [Paper](http://arxiv.org/abs/2010.13002)
|
||||
|
||||

|
||||
|
||||
+ For the CNN/DailyMail dataset, (relatively longer, more extractive summaries), we found a simple technique that works, which we call "Shrink and Fine-tune", or SFT.
|
||||
you just copy alternating layers from `facebook/bart-large-cnn` and fine-tune more on the cnn/dm data. `sshleifer/distill-pegasus-cnn-16-4`, `sshleifer/distilbart-cnn-12-6` and all other checkpoints under `sshleifer` that start with `distilbart-cnn` were trained this way.
|
||||
+ For the XSUM dataset, training on pseudo-labels worked best for Pegasus (`sshleifer/distill-pegasus-16-4`), while training with KD worked best for `distilbart-xsum-12-6`
|
||||
+ For `sshleifer/dbart-xsum-12-3`
|
||||
+ We ran 100s experiments, and didn't want to document 100s of commands. If you want a command to replicate a figure from the paper that is not documented below, feel free to ask on the [forums](https://discuss.huggingface.co/t/seq2seq-distillation-methodology-questions/1270) and tag `@sshleifer`.
|
||||
+ You can see the performance tradeoffs of model sizes [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=0).
|
||||
and more granular timing results [here](https://docs.google.com/spreadsheets/d/1EkhDMwVO02m8jCD1cG3RoFPLicpcL1GQHTQjfvDYgIM/edit#gid=1753259047&range=B2:I23).
|
||||
|
||||
### Evaluation
|
||||
|
||||
use [run_distributed_eval](./run_distributed_eval.py), with the following convenient alias
|
||||
```bash
|
||||
deval () {
|
||||
proc=$1
|
||||
m=$2
|
||||
dd=$3
|
||||
sd=$4
|
||||
shift
|
||||
shift
|
||||
shift
|
||||
shift
|
||||
python -m torch.distributed.launch --nproc_per_node=$proc run_distributed_eval.py \
|
||||
--model_name $m --save_dir $sd --data_dir $dd $@
|
||||
}
|
||||
```
|
||||
On a 1 GPU system, here are four commands (that assume `xsum`, `cnn_dm` are downloaded, cmd-F for those links in this file).
|
||||
|
||||
`distilBART`:
|
||||
```bash
|
||||
deval 1 sshleifer/distilbart-xsum-12-3 xsum dbart_12_3_xsum_eval --fp16 # --help for more choices.
|
||||
deval 1 sshleifer/distilbart-cnn_dm-12-6 cnn_dm dbart_12_6_cnn_eval --fp16
|
||||
```
|
||||
|
||||
`distill-pegasus`:
|
||||
```bash
|
||||
deval 1 sshleifer/distill-pegasus-cnn-16-4 cnn_dm dpx_cnn_eval
|
||||
deval 1 sshleifer/distill-pegasus-xsum-16-4 xsum dpx_xsum_eval
|
||||
```
|
||||
|
||||
### Distillation
|
||||
+ For all of the following commands, you can get roughly equivalent result and faster run times by passing `--num_beams=4`. That's not what we did for the paper.
|
||||
+ Besides the KD section, you can also run commands with the built-in transformers trainer. See, for example, [builtin_trainer/train_distilbart_cnn.sh](./builtin_trainer/train_distilbart_cnn.sh).
|
||||
+ Large performance deviations (> 5X slower or more than 0.5 Rouge-2 worse), should be reported.
|
||||
+ Multi-gpu (controlled with `--gpus` should work, but might require more epochs).
|
||||
|
||||
#### Recommended Workflow
|
||||
+ Get your dataset in the right format. (see 6 files above).
|
||||
+ Find a teacher model [Pegasus](https://huggingface.co/models?search=pegasus) (slower, better ROUGE) or `facebook/bart-large-xsum`/`facebook/bart-large-cnn` (faster, slightly lower.).
|
||||
Choose the checkpoint where the corresponding dataset is most similar (or identical to) your dataset.
|
||||
+ Follow the sections in order below. You can stop after SFT if you are satisfied, or move on to pseudo-labeling if you want more performance.
|
||||
+ student size: If you want a close to free 50% speedup, cut the decoder in half. If you want a larger speedup, cut it in 4.
|
||||
+ If your SFT run starts at a validation ROUGE-2 that is more than 10 pts below the teacher's validation ROUGE-2, you have a bug. Switching to a more expensive technique will not help. Try setting a breakpoint and looking at generation and truncation defaults/hyper-parameters, and share your experience on the forums!
|
||||
|
||||
|
||||
#### Initialization
|
||||
We use [make_student.py](./make_student.py) to copy alternating layers from the teacher, and save the resulting model to disk
|
||||
```bash
|
||||
python make_student.py facebook/bart-large-xsum --save_path dbart_xsum_12_3 -e 12 -d 3
|
||||
```
|
||||
or for `pegasus-xsum`
|
||||
```bash
|
||||
python make_student.py google/pegasus-xsum --save_path dpx_xsum_16_4 --e 16 --d 4
|
||||
```
|
||||
we now have an initialized student saved to `dbart_xsum_12_3`, which we will use for the following commands.
|
||||
+ Extension: To replicate more complicated initialize experiments in section 6.1, or try your own. Use the `create_student_by_copying_alternating_layers` function.
|
||||
|
||||
#### Pegasus
|
||||
+ The following commands are written for BART and will require, at minimum, the following modifications
|
||||
+ reduce batch size, and increase gradient accumulation steps so that the product `gpus * batch size * gradient_accumulation_steps = 256`. We used `--learning-rate` = 1e-4 * gradient accumulation steps.
|
||||
+ don't use fp16
|
||||
+ `--tokenizer_name google/pegasus-large`
|
||||
|
||||
### SFT (No Teacher Distillation)
|
||||
You don't need `distillation.py`, you can just run:
|
||||
|
||||
```bash
|
||||
python finetune.py \
|
||||
--data_dir xsum \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--fp16 --fp16_opt_level=O1 \
|
||||
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
|
||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||
--model_name_or_path dbart_xsum_12_3 \
|
||||
--train_batch_size=64 --eval_batch_size=64 \
|
||||
--sortish_sampler \
|
||||
--num_train_epochs=6 \
|
||||
--warmup_steps 500 \
|
||||
--output_dir distilbart_xsum_sft_12_3 --gpus 1
|
||||
```
|
||||
|
||||
+ Note: The command that produced `sshleifer/distilbart-cnn-12-6` is at [train_distilbart_cnn.sh](./[train_distilbart_cnn.sh)
|
||||
|
||||
```bash
|
||||
./train_distilbart_cnn.sh
|
||||
```
|
||||
<!--- runtime: 6H on NVIDIA RTX 24GB GPU -->
|
||||
+ Tip: You can get the same simple distillation logic by using `distillation.py --no_teacher ` followed by identical arguments as the ones in `train_distilbart_cnn.sh`.
|
||||
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
|
||||
because you will have the same hyper-parameters logged in every run.
|
||||
|
||||
### Pseudo-Labeling
|
||||
+ You don't need `distillation.py`.
|
||||
+ Instructions to generate pseudo-labels and use pre-computed pseudo-labels can be found [here](./precomputed_pseudo_labels.md).
|
||||
Simply run `finetune.py` with one of those pseudo-label datasets as `--data_dir` (`DATA`, below).
|
||||
|
||||
```bash
|
||||
python finetune.py \
|
||||
--teacher facebook/bart-large-xsum --data_dir DATA \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--fp16 --fp16_opt_level=O1 \
|
||||
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
|
||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||
--model_name_or_path dbart_xsum_12_3 \
|
||||
--train_batch_size=32 --eval_batch_size=32 \
|
||||
--sortish_sampler \
|
||||
--num_train_epochs=5 \
|
||||
--warmup_steps 500 \
|
||||
--output_dir dbart_xsum_12_3_PL --gpus 1 --logger_name wandb
|
||||
```
|
||||
|
||||
|
||||
|
||||
To combine datasets, as in Section 6.2, try something like:
|
||||
```bash
|
||||
curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/bart_xsum_pl.tgz | tar -xvz -C .
|
||||
curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/pegasus_xsum.tgz | tar -xvz -C .
|
||||
curl -S https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz | tar -xvz -C .
|
||||
mkdir all_pl
|
||||
cat bart_xsum_pl/train.source pegasus_xsum/train.source xsum/train.source > all_pl/train.source
|
||||
cat bart_xsum_pl/train.target pegasus_xsum/train.target xsum/train.target > all_pl/train.target
|
||||
cp xsum/val* all_pl
|
||||
cp xsum/test* all_pl
|
||||
```
|
||||
then use `all_pl` as DATA in the command above.
|
||||
|
||||
#### Direct Knowledge Distillation (KD)
|
||||
+ In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `SummarizationDistiller`.
|
||||
+ This method was used for `sshleifer/distilbart-xsum-12-6`, `6-6`, and `9-6` checkpoints were produced.
|
||||
+ You must use [`distillation.py`](./distillation.py). Note that this command initializes the student for you.
|
||||
|
||||
The command that produced `sshleifer/distilbart-xsum-12-6` is at [./train_distilbart_xsum.sh](train_distilbart_xsum.sh)
|
||||
```bash
|
||||
./train_distilbart_xsum.sh --logger_name wandb --gpus 1
|
||||
```
|
||||
|
||||
+ Expected ROUGE-2 between 21.3 and 21.6, run time ~13H.
|
||||
+ direct KD + Pegasus is VERY slow and works best with `--supervise_forward --normalize_hidden`.
|
||||
|
||||
<!--- runtime: 13H on V-100 16GB GPU. -->
|
||||
|
||||
### Citation
|
||||
|
||||
```bibtex
|
||||
@misc{shleifer2020pretrained,
|
||||
title={Pre-trained Summarization Distillation},
|
||||
author={Sam Shleifer and Alexander M. Rush},
|
||||
year={2020},
|
||||
eprint={2010.13002},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
@article{Wolf2019HuggingFacesTS,
|
||||
title={HuggingFace's Transformers: State-of-the-art Natural Language Processing},
|
||||
author={Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Joe Davison and Sam Shleifer and Patrick von Platen and Clara Ma and Yacine Jernite and Julien Plu and Canwen Xu and Teven Le Scao and Sylvain Gugger and Mariama Drame and Quentin Lhoest and Alexander M. Rush},
|
||||
journal={ArXiv},
|
||||
year={2019},
|
||||
volume={abs/1910.03771}
|
||||
}
|
||||
```
|
||||
|
||||
This is the end of the distillation section, the rest of this doc pertains to general seq2seq commands.
|
||||
|
||||
## Evaluation Commands
|
||||
|
||||
To create summaries for each article in dataset, we use `run_eval.py`, here are a few commands that run eval for different tasks and models.
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
# Text Summarization with Pretrained Encoders
|
||||
|
||||
This folder contains part of the code necessary to reproduce the results on abstractive summarization from the article [Text Summarization with Pretrained Encoders](https://arxiv.org/pdf/1908.08345.pdf) by [Yang Liu](https://nlp-yang.github.io/) and [Mirella Lapata](https://homepages.inf.ed.ac.uk/mlap/). It can also be used to summarize any document.
|
||||
|
||||
The original code can be found on the Yang Liu's [github repository](https://github.com/nlpyang/PreSumm).
|
||||
|
||||
The model is loaded with the pre-trained weights for the abstractive summarization model trained on the CNN/Daily Mail dataset with an extractive and then abstractive tasks.
|
||||
|
||||
## Setup
|
||||
|
||||
```
|
||||
git clone https://github.com/huggingface/transformers && cd transformers
|
||||
pip install .
|
||||
pip install nltk py-rouge
|
||||
cd examples/seq2seq/bertabs
|
||||
```
|
||||
|
||||
## Reproduce the authors' ROUGE score
|
||||
|
||||
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
|
||||
|
||||
```bash
|
||||
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
|
||||
```
|
||||
|
||||
And move all the stories to the same folder. We will refer as `$DATA_PATH` the path to where you uncompressed both archive. Then run the following in the same folder as `run_summarization.py`:
|
||||
|
||||
```bash
|
||||
python run_summarization.py \
|
||||
--documents_dir $DATA_PATH \
|
||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||
--no_cuda false \
|
||||
--batch_size 4 \
|
||||
--min_length 50 \
|
||||
--max_length 200 \
|
||||
--beam_size 5 \
|
||||
--alpha 0.95 \
|
||||
--block_trigram true \
|
||||
--compute_rouge true
|
||||
```
|
||||
|
||||
The scripts executes on GPU if one is available and if `no_cuda` is not set to `true`. Inference on multiple GPUs is not supported yet. The ROUGE scores will be displayed in the console at the end of evaluation and written in a `rouge_scores.txt` file. The script takes 30 hours to compute with a single Tesla V100 GPU and a batch size of 10 (300,000 texts to summarize).
|
||||
|
||||
## Summarize any text
|
||||
|
||||
Put the documents that you would like to summarize in a folder (the path to which is referred to as `$DATA_PATH` below) and run the following in the same folder as `run_summarization.py`:
|
||||
|
||||
```bash
|
||||
python run_summarization.py \
|
||||
--documents_dir $DATA_PATH \
|
||||
--summaries_output_dir $SUMMARIES_PATH \ # optional
|
||||
--no_cuda false \
|
||||
--batch_size 4 \
|
||||
--min_length 50 \
|
||||
--max_length 200 \
|
||||
--beam_size 5 \
|
||||
--alpha 0.95 \
|
||||
--block_trigram true \
|
||||
```
|
||||
|
||||
You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries import in `utils_summarization.py` and tell it where to fetch the reference summaries.
|
||||
@@ -1,97 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BertAbs configuration """
|
||||
import logging
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BERTABS_FINETUNED_CONFIG_MAP = {
|
||||
"bertabs-finetuned-cnndm": "https://huggingface.co/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class BertAbsConfig(PretrainedConfig):
|
||||
r"""Class to store the configuration of the BertAbs model.
|
||||
|
||||
Arguments:
|
||||
vocab_size: int
|
||||
Number of tokens in the vocabulary.
|
||||
max_pos: int
|
||||
The maximum sequence length that this model will be used with.
|
||||
enc_layer: int
|
||||
The numner of hidden layers in the Transformer encoder.
|
||||
enc_hidden_size: int
|
||||
The size of the encoder's layers.
|
||||
enc_heads: int
|
||||
The number of attention heads for each attention layer in the encoder.
|
||||
enc_ff_size: int
|
||||
The size of the encoder's feed-forward layers.
|
||||
enc_dropout: int
|
||||
The dropout probability for all fully connected layers in the
|
||||
embeddings, layers, pooler and also the attention probabilities in
|
||||
the encoder.
|
||||
dec_layer: int
|
||||
The numner of hidden layers in the decoder.
|
||||
dec_hidden_size: int
|
||||
The size of the decoder's layers.
|
||||
dec_heads: int
|
||||
The number of attention heads for each attention layer in the decoder.
|
||||
dec_ff_size: int
|
||||
The size of the decoder's feed-forward layers.
|
||||
dec_dropout: int
|
||||
The dropout probability for all fully connected layers in the
|
||||
embeddings, layers, pooler and also the attention probabilities in
|
||||
the decoder.
|
||||
"""
|
||||
|
||||
model_type = "bertabs"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
max_pos=512,
|
||||
enc_layers=6,
|
||||
enc_hidden_size=512,
|
||||
enc_heads=8,
|
||||
enc_ff_size=512,
|
||||
enc_dropout=0.2,
|
||||
dec_layers=6,
|
||||
dec_hidden_size=768,
|
||||
dec_heads=8,
|
||||
dec_ff_size=2048,
|
||||
dec_dropout=0.2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_pos = max_pos
|
||||
|
||||
self.enc_layers = enc_layers
|
||||
self.enc_hidden_size = enc_hidden_size
|
||||
self.enc_heads = enc_heads
|
||||
self.enc_ff_size = enc_ff_size
|
||||
self.enc_dropout = enc_dropout
|
||||
|
||||
self.dec_layers = dec_layers
|
||||
self.dec_hidden_size = dec_hidden_size
|
||||
self.dec_heads = dec_heads
|
||||
self.dec_ff_size = dec_ff_size
|
||||
self.dec_dropout = dec_dropout
|
||||
@@ -1,185 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Convert BertExtAbs's checkpoints.
|
||||
|
||||
The script looks like it is doing something trivial but it is not. The "weights"
|
||||
proposed by the authors are actually the entire model pickled. We need to load
|
||||
the model within the original codebase to be able to only save its `state_dict`.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from model_bertabs import BertAbsSummarizer
|
||||
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||
|
||||
|
||||
BertAbsConfig = namedtuple(
|
||||
"BertAbsConfig",
|
||||
[
|
||||
"temp_dir",
|
||||
"large",
|
||||
"use_bert_emb",
|
||||
"finetune_bert",
|
||||
"encoder",
|
||||
"share_emb",
|
||||
"max_pos",
|
||||
"enc_layers",
|
||||
"enc_hidden_size",
|
||||
"enc_heads",
|
||||
"enc_ff_size",
|
||||
"enc_dropout",
|
||||
"dec_layers",
|
||||
"dec_hidden_size",
|
||||
"dec_heads",
|
||||
"dec_ff_size",
|
||||
"dec_dropout",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
"""Copy/paste and tweak the pre-trained weights provided by the creators
|
||||
of BertAbs for the internal architecture.
|
||||
"""
|
||||
|
||||
# Instantiate the authors' model with the pre-trained weights
|
||||
config = BertAbsConfig(
|
||||
temp_dir=".",
|
||||
finetune_bert=False,
|
||||
large=False,
|
||||
share_emb=True,
|
||||
use_bert_emb=False,
|
||||
encoder="bert",
|
||||
max_pos=512,
|
||||
enc_layers=6,
|
||||
enc_hidden_size=512,
|
||||
enc_heads=8,
|
||||
enc_ff_size=512,
|
||||
enc_dropout=0.2,
|
||||
dec_layers=6,
|
||||
dec_hidden_size=768,
|
||||
dec_heads=8,
|
||||
dec_ff_size=2048,
|
||||
dec_dropout=0.2,
|
||||
)
|
||||
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
|
||||
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
|
||||
original.eval()
|
||||
|
||||
new_model = BertAbsSummarizer(config, torch.device("cpu"))
|
||||
new_model.eval()
|
||||
|
||||
# -------------------
|
||||
# Convert the weights
|
||||
# -------------------
|
||||
|
||||
logging.info("convert the model")
|
||||
new_model.bert.load_state_dict(original.bert.state_dict())
|
||||
new_model.decoder.load_state_dict(original.decoder.state_dict())
|
||||
new_model.generator.load_state_dict(original.generator.state_dict())
|
||||
|
||||
# ----------------------------------
|
||||
# Make sure the outpus are identical
|
||||
# ----------------------------------
|
||||
|
||||
logging.info("Make sure that the models' outputs are identical")
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
# prepare the model inputs
|
||||
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
|
||||
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
|
||||
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
|
||||
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
|
||||
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
|
||||
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
|
||||
|
||||
# failsafe to make sure the weights reset does not affect the
|
||||
# loaded weights.
|
||||
assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0
|
||||
|
||||
# forward pass
|
||||
src = encoder_input_ids
|
||||
tgt = decoder_input_ids
|
||||
segs = token_type_ids = None
|
||||
clss = None
|
||||
mask_src = encoder_attention_mask = None
|
||||
mask_tgt = decoder_attention_mask = None
|
||||
mask_cls = None
|
||||
|
||||
# The original model does not apply the geneator layer immediatly but rather in
|
||||
# the beam search (where it combines softmax + linear layer). Since we already
|
||||
# apply the softmax in our generation process we only apply the linear layer here.
|
||||
# We make sure that the outputs of the full stack are identical
|
||||
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
|
||||
output_original_generator = original.generator(output_original_model)
|
||||
|
||||
output_converted_model = new_model(
|
||||
encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
|
||||
)[0]
|
||||
output_converted_generator = new_model.generator(output_converted_model)
|
||||
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
|
||||
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
|
||||
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
|
||||
|
||||
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
|
||||
if are_identical:
|
||||
logging.info("all weights are equal up to 1e-3")
|
||||
else:
|
||||
raise ValueError("the weights are different. The new model is likely different from the original one.")
|
||||
|
||||
# The model has been saved with torch.save(model) and this is bound to the exact
|
||||
# directory structure. We save the state_dict instead.
|
||||
logging.info("saving the model's state dictionary")
|
||||
torch.save(
|
||||
new_model.state_dict(), "./bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bertabs_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path the official PyTorch dump.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_bertabs_checkpoints(
|
||||
args.bertabs_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +0,0 @@
|
||||
transformers
|
||||
|
||||
# For ROUGE
|
||||
nltk
|
||||
py-rouge
|
||||
@@ -1,347 +0,0 @@
|
||||
#! /usr/bin/python3
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from modeling_bertabs import BertAbs, build_predictor
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from .utils_summarization import (
|
||||
CNNDMDataset,
|
||||
build_mask,
|
||||
compute_token_type_ids,
|
||||
encode_for_summarization,
|
||||
truncate_or_pad,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
|
||||
model = BertAbs.from_pretrained("remi/bertabs-finetuned-extractive-abstractive-summarization")
|
||||
model.to(args.device)
|
||||
model.eval()
|
||||
|
||||
symbols = {
|
||||
"BOS": tokenizer.vocab["[unused0]"],
|
||||
"EOS": tokenizer.vocab["[unused1]"],
|
||||
"PAD": tokenizer.vocab["[PAD]"],
|
||||
}
|
||||
|
||||
if args.compute_rouge:
|
||||
reference_summaries = []
|
||||
generated_summaries = []
|
||||
|
||||
import nltk
|
||||
|
||||
import rouge
|
||||
|
||||
nltk.download("punkt")
|
||||
rouge_evaluator = rouge.Rouge(
|
||||
metrics=["rouge-n", "rouge-l"],
|
||||
max_n=2,
|
||||
limit_length=True,
|
||||
length_limit=args.beam_size,
|
||||
length_limit_type="words",
|
||||
apply_avg=True,
|
||||
apply_best=False,
|
||||
alpha=0.5, # Default F1_score
|
||||
weight_factor=1.2,
|
||||
stemming=True,
|
||||
)
|
||||
|
||||
# these (unused) arguments are defined to keep the compatibility
|
||||
# with the legacy code and will be deleted in a next iteration.
|
||||
args.result_path = ""
|
||||
args.temp_dir = ""
|
||||
|
||||
data_iterator = build_data_iterator(args, tokenizer)
|
||||
predictor = build_predictor(args, tokenizer, symbols, model)
|
||||
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Number examples = %d", len(data_iterator.dataset))
|
||||
logger.info(" Batch size = %d", args.batch_size)
|
||||
logger.info("")
|
||||
logger.info("***** Beam Search parameters *****")
|
||||
logger.info(" Beam size = %d", args.beam_size)
|
||||
logger.info(" Minimum length = %d", args.min_length)
|
||||
logger.info(" Maximum length = %d", args.max_length)
|
||||
logger.info(" Alpha (length penalty) = %.2f", args.alpha)
|
||||
logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))
|
||||
|
||||
for batch in tqdm(data_iterator):
|
||||
batch_data = predictor.translate_batch(batch)
|
||||
translations = predictor.from_batch(batch_data)
|
||||
summaries = [format_summary(t) for t in translations]
|
||||
save_summaries(summaries, args.summaries_output_dir, batch.document_names)
|
||||
|
||||
if args.compute_rouge:
|
||||
reference_summaries += batch.tgt_str
|
||||
generated_summaries += summaries
|
||||
|
||||
if args.compute_rouge:
|
||||
scores = rouge_evaluator.get_scores(generated_summaries, reference_summaries)
|
||||
str_scores = format_rouge_scores(scores)
|
||||
save_rouge_scores(str_scores)
|
||||
print(str_scores)
|
||||
|
||||
|
||||
def save_summaries(summaries, path, original_document_name):
|
||||
"""Write the summaries in fies that are prefixed by the original
|
||||
files' name with the `_summary` appended.
|
||||
|
||||
Attributes:
|
||||
original_document_names: List[string]
|
||||
Name of the document that was summarized.
|
||||
path: string
|
||||
Path were the summaries will be written
|
||||
summaries: List[string]
|
||||
The summaries that we produced.
|
||||
"""
|
||||
for summary, document_name in zip(summaries, original_document_name):
|
||||
# Prepare the summary file's name
|
||||
if "." in document_name:
|
||||
bare_document_name = ".".join(document_name.split(".")[:-1])
|
||||
extension = document_name.split(".")[-1]
|
||||
name = bare_document_name + "_summary." + extension
|
||||
else:
|
||||
name = document_name + "_summary"
|
||||
|
||||
file_path = os.path.join(path, name)
|
||||
with open(file_path, "w") as output:
|
||||
output.write(summary)
|
||||
|
||||
|
||||
def format_summary(translation):
|
||||
"""Transforms the output of the `from_batch` function
|
||||
into nicely formatted summaries.
|
||||
"""
|
||||
raw_summary, _, _ = translation
|
||||
summary = (
|
||||
raw_summary.replace("[unused0]", "")
|
||||
.replace("[unused3]", "")
|
||||
.replace("[PAD]", "")
|
||||
.replace("[unused1]", "")
|
||||
.replace(r" +", " ")
|
||||
.replace(" [unused2] ", ". ")
|
||||
.replace("[unused2]", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def format_rouge_scores(scores):
|
||||
return """\n
|
||||
****** ROUGE SCORES ******
|
||||
|
||||
** ROUGE 1
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}
|
||||
|
||||
** ROUGE 2
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}
|
||||
|
||||
** ROUGE L
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}""".format(
|
||||
scores["rouge-1"]["f"],
|
||||
scores["rouge-1"]["p"],
|
||||
scores["rouge-1"]["r"],
|
||||
scores["rouge-2"]["f"],
|
||||
scores["rouge-2"]["p"],
|
||||
scores["rouge-2"]["r"],
|
||||
scores["rouge-l"]["f"],
|
||||
scores["rouge-l"]["p"],
|
||||
scores["rouge-l"]["r"],
|
||||
)
|
||||
|
||||
|
||||
def save_rouge_scores(str_scores):
|
||||
with open("rouge_scores.txt", "w") as output:
|
||||
output.write(str_scores)
|
||||
|
||||
|
||||
#
|
||||
# LOAD the dataset
|
||||
#
|
||||
|
||||
|
||||
def build_data_iterator(args, tokenizer):
|
||||
dataset = load_and_cache_examples(args, tokenizer)
|
||||
sampler = SequentialSampler(dataset)
|
||||
|
||||
def collate_fn(data):
|
||||
return collate(data, tokenizer, block_size=512, device=args.device)
|
||||
|
||||
iterator = DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
return iterator
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer):
|
||||
dataset = CNNDMDataset(args.documents_dir)
|
||||
return dataset
|
||||
|
||||
|
||||
def collate(data, tokenizer, block_size, device):
|
||||
"""Collate formats the data passed to the data loader.
|
||||
|
||||
In particular we tokenize the data batch after batch to avoid keeping them
|
||||
all in memory. We output the data as a namedtuple to fit the original BertAbs's
|
||||
API.
|
||||
"""
|
||||
data = [x for x in data if not len(x[1]) == 0] # remove empty_files
|
||||
names = [name for name, _, _ in data]
|
||||
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
||||
|
||||
encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
|
||||
encoded_stories = torch.tensor(
|
||||
[truncate_or_pad(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
|
||||
)
|
||||
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
||||
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
||||
|
||||
batch = Batch(
|
||||
document_names=names,
|
||||
batch_size=len(encoded_stories),
|
||||
src=encoded_stories.to(device),
|
||||
segs=encoder_token_type_ids.to(device),
|
||||
mask_src=encoder_mask.to(device),
|
||||
tgt_str=summaries,
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def decode_summary(summary_tokens, tokenizer):
|
||||
"""Decode the summary and return it in a format
|
||||
suitable for evaluation.
|
||||
"""
|
||||
summary_tokens = summary_tokens.to("cpu").numpy()
|
||||
summary = tokenizer.decode(summary_tokens)
|
||||
sentences = summary.split(".")
|
||||
sentences = [s + "." for s in sentences]
|
||||
return sentences
|
||||
|
||||
|
||||
def main():
|
||||
"""The main function defines the interface with the users."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--documents_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The folder where the documents to summarize are located.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--summaries_output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compute_rouge",
|
||||
default=False,
|
||||
type=bool,
|
||||
required=False,
|
||||
help="Compute the ROUGE metrics during evaluation. Only available for the CNN/DailyMail dataset.",
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--no_cuda",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
# BEAM SEARCH arguments
|
||||
parser.add_argument(
|
||||
"--min_length",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Minimum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
default=200,
|
||||
type=int,
|
||||
help="Maixmum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=5,
|
||||
type=int,
|
||||
help="The number of beams to start with for each example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
default=0.95,
|
||||
type=float,
|
||||
help="The value of alpha for the length penalty in the beam search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_trigram",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Select device (distibuted not available)
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
|
||||
# Check the existence of directories
|
||||
if not args.summaries_output_dir:
|
||||
args.summaries_output_dir = args.documents_dir
|
||||
|
||||
if not documents_dir_is_valid(args.documents_dir):
|
||||
raise FileNotFoundError(
|
||||
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
|
||||
)
|
||||
os.makedirs(args.summaries_output_dir, exist_ok=True)
|
||||
|
||||
evaluate(args)
|
||||
|
||||
|
||||
def documents_dir_is_valid(path):
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
|
||||
file_list = os.listdir(path)
|
||||
if len(file_list) == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,98 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .utils_summarization import build_mask, compute_token_type_ids, process_story, truncate_or_pad
|
||||
|
||||
|
||||
class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.block_size = 10
|
||||
|
||||
def test_fit_to_block_sequence_too_small(self):
|
||||
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
sequence = [1, 2, 3, 4]
|
||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_fit_exactly(self):
|
||||
""" Do nothing if the sequence is the right size. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_too_big(self):
|
||||
""" Truncate the sequence if it is too long. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
"""Processing a story with no highlights returns an empty list for the summary."""
|
||||
raw_story = """It was the year of Our Lord one thousand seven hundred and
|
||||
seventy-five.\n\nSpiritual revelations were conceded to England at that
|
||||
favoured period, as at this."""
|
||||
_, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_empty_story(self):
|
||||
"""An empty story returns an empty collection of lines."""
|
||||
raw_story = ""
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(story_lines, [])
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_story_with_missing_period(self):
|
||||
raw_story = (
|
||||
"It was the year of Our Lord one thousand seven hundred and "
|
||||
"seventy-five\n\nSpiritual revelations were conceded to England "
|
||||
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
|
||||
)
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
|
||||
expected_story_lines = [
|
||||
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
|
||||
"Spiritual revelations were conceded to England at that favoured period, as at this.",
|
||||
]
|
||||
self.assertEqual(expected_story_lines, story_lines)
|
||||
|
||||
expected_summary_lines = ["It was the best of times."]
|
||||
self.assertEqual(expected_summary_lines, summary_lines)
|
||||
|
||||
def test_build_mask_no_padding(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4])
|
||||
expected = torch.tensor([1, 1, 1, 1])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 0).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 23).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask_with_padding_equal_to_one(self):
|
||||
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(build_mask(sequence, 1).numpy(), expected.numpy())
|
||||
|
||||
def test_compute_token_type_ids(self):
|
||||
separator = 101
|
||||
batch = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]])
|
||||
expected = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]])
|
||||
|
||||
result = compute_token_type_ids(batch, separator)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
@@ -1,167 +0,0 @@
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# ------------
|
||||
# Data loading
|
||||
# ------------
|
||||
|
||||
|
||||
class CNNDMDataset(Dataset):
|
||||
"""Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
The class will process the documents that are located in the specified
|
||||
folder. The preprocessing will work on any document that is reasonably
|
||||
formatted. On the CNN/DailyMail dataset it will extract both the story
|
||||
and the summary.
|
||||
|
||||
CNN/Daily News:
|
||||
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
|
||||
stored in different files; the summary appears at the end of the story as
|
||||
sentences that are prefixed by the special `@highlight` line. To process
|
||||
the data, untar both datasets in the same folder, and pass the path to this
|
||||
folder as the "data_dir argument. The formatting code was inspired by [2].
|
||||
|
||||
[1] https://cs.nyu.edu/~kcho/
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
|
||||
def __init__(self, path="", prefix="train"):
|
||||
"""We initialize the class by listing all the documents to summarize.
|
||||
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
|
||||
"""
|
||||
assert os.path.isdir(path)
|
||||
|
||||
self.documents = []
|
||||
story_filenames_list = os.listdir(path)
|
||||
for story_filename in story_filenames_list:
|
||||
if "summary" in story_filename:
|
||||
continue
|
||||
path_to_story = os.path.join(path, story_filename)
|
||||
if not os.path.isfile(path_to_story):
|
||||
continue
|
||||
self.documents.append(path_to_story)
|
||||
|
||||
def __len__(self):
|
||||
""" Returns the number of documents. """
|
||||
return len(self.documents)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
document_path = self.documents[idx]
|
||||
document_name = document_path.split("/")[-1]
|
||||
with open(document_path, encoding="utf-8") as source:
|
||||
raw_story = source.read()
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
return document_name, story_lines, summary_lines
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
"""Extract the story and summary from a story file.
|
||||
|
||||
Arguments:
|
||||
raw_story (str): content of the story file as an utf-8 encoded string.
|
||||
|
||||
Raises:
|
||||
IndexError: If the story is empty or contains no highlights.
|
||||
"""
|
||||
nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
|
||||
|
||||
# for some unknown reason some lines miss a period, add it
|
||||
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
||||
|
||||
# gather article lines
|
||||
story_lines = []
|
||||
lines = deque(nonempty_lines)
|
||||
while True:
|
||||
try:
|
||||
element = lines.popleft()
|
||||
if element.startswith("@highlight"):
|
||||
break
|
||||
story_lines.append(element)
|
||||
except IndexError:
|
||||
# if "@highlight" is absent from the file we pop
|
||||
# all elements until there is None, raising an exception.
|
||||
return story_lines, []
|
||||
|
||||
# gather summary lines
|
||||
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
|
||||
|
||||
return story_lines, summary_lines
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u2019", ")"]
|
||||
if line.startswith("@highlight"):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + "."
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Encoding and preprocessing
|
||||
# --------------------------
|
||||
|
||||
|
||||
def truncate_or_pad(sequence, block_size, pad_token_id):
|
||||
"""Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter we append padding token to the right of the sequence.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
sequence.extend([pad_token_id] * (block_size - len(sequence)))
|
||||
return sequence
|
||||
|
||||
|
||||
def build_mask(sequence, pad_token_id):
|
||||
"""Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1."""
|
||||
mask = torch.ones_like(sequence)
|
||||
idx_pad_tokens = sequence == pad_token_id
|
||||
mask[idx_pad_tokens] = 0
|
||||
return mask
|
||||
|
||||
|
||||
def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
"""Encode the story and summary lines, and join them
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
||||
story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
|
||||
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
||||
summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]
|
||||
|
||||
return story_token_ids, summary_token_ids
|
||||
|
||||
|
||||
def compute_token_type_ids(batch, separator_token_id):
|
||||
"""Segment embeddings as described in [1]
|
||||
|
||||
The values {0,1} were found in the repository [2].
|
||||
|
||||
Attributes:
|
||||
batch: torch.Tensor, size [batch_size, block_size]
|
||||
Batch of input.
|
||||
separator_token_id: int
|
||||
The value of the token that separates the segments.
|
||||
|
||||
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
|
||||
arXiv preprint arXiv:1908.08345 (2019).
|
||||
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
|
||||
"""
|
||||
batch_embeddings = []
|
||||
for sequence in batch:
|
||||
sentence_num = -1
|
||||
embeddings = []
|
||||
for s in sequence:
|
||||
if s == separator_token_id:
|
||||
sentence_num += 1
|
||||
embeddings.append(sentence_num % 2)
|
||||
batch_embeddings.append(embeddings)
|
||||
return torch.tensor(batch_embeddings)
|
||||
@@ -1,10 +0,0 @@
|
||||
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./builtin_trainer/finetune.sh --help to see all the possible options
|
||||
python finetune_trainer.py \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate \
|
||||
--n_val 1000 \
|
||||
"$@"
|
||||
@@ -1,12 +0,0 @@
|
||||
export TPU_NUM_CORES=8
|
||||
|
||||
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./builtin_trainer/finetune_tpu.sh --help to see all the possible options
|
||||
python xla_spawn.py --num_cores $TPU_NUM_CORES \
|
||||
finetune_trainer.py \
|
||||
--learning_rate=3e-5 \
|
||||
--do_train --do_eval \
|
||||
--evaluation_strategy steps \
|
||||
--prediction_loss_only \
|
||||
--n_val 1000 \
|
||||
"$@"
|
||||
@@ -1,25 +0,0 @@
|
||||
export WANDB_PROJECT=distilbart-trainer
|
||||
export BS=32
|
||||
export m=sshleifer/student_cnn_12_6
|
||||
export tok=facebook/bart-large
|
||||
export MAX_TGT_LEN=142
|
||||
|
||||
python finetune_trainer.py \
|
||||
--model_name_or_path $m --tokenizer_name $tok \
|
||||
--data_dir cnn_dm \
|
||||
--output_dir distilbart-cnn-12-6 --overwrite_output_dir \
|
||||
--learning_rate=3e-5 \
|
||||
--warmup_steps 500 --sortish_sampler \
|
||||
--fp16 \
|
||||
--n_val 500 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--num_train_epochs=2 \
|
||||
--save_steps 3000 --eval_steps 3000 \
|
||||
--logging_first_step \
|
||||
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate --sortish_sampler \
|
||||
"$@"
|
||||
@@ -1,22 +0,0 @@
|
||||
python finetune_trainer.py \
|
||||
--model_name_or_path=facebook/mbart-large-cc25 \
|
||||
--data_dir $ENRO_DIR \
|
||||
--output_dir mbart_cc25_enro --overwrite_output_dir \
|
||||
--learning_rate=3e-5 \
|
||||
--warmup_steps 500 \
|
||||
--fp16 \
|
||||
--label_smoothing 0.1 \
|
||||
--adam_eps 1e-06 \
|
||||
--src_lang en_XX --tgt_lang ro_RO \
|
||||
--freeze_embeds \
|
||||
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
|
||||
--max_source_length 128 --max_target_length 128 \
|
||||
--val_max_target_length 128 --test_max_target_length 128 \
|
||||
--sortish_sampler \
|
||||
--num_train_epochs 6 \
|
||||
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate --logging_first_step \
|
||||
--task translation \
|
||||
"$@"
|
||||
@@ -1,115 +0,0 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from utils import save_json
|
||||
|
||||
|
||||
def count_trainable_parameters(model):
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
return params
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqLoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
@rank_zero_only
|
||||
def _write_logs(
|
||||
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||
) -> None:
|
||||
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
|
||||
metrics = trainer.callback_metrics
|
||||
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
|
||||
# Log results
|
||||
od = Path(pl_module.hparams.output_dir)
|
||||
if type_path == "test":
|
||||
results_file = od / "test_results.txt"
|
||||
generations_file = od / "test_generations.txt"
|
||||
else:
|
||||
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
|
||||
# If people want this it will be easy enough to add back.
|
||||
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
|
||||
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
|
||||
results_file.parent.mkdir(exist_ok=True)
|
||||
generations_file.parent.mkdir(exist_ok=True)
|
||||
with open(results_file, "a+") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key in ["log", "progress_bar", "preds"]:
|
||||
continue
|
||||
val = metrics[key]
|
||||
if isinstance(val, torch.Tensor):
|
||||
val = val.item()
|
||||
msg = f"{key}: {val:.6f}\n"
|
||||
writer.write(msg)
|
||||
|
||||
if not save_generations:
|
||||
return
|
||||
|
||||
if "preds" in metrics:
|
||||
content = "\n".join(metrics["preds"])
|
||||
generations_file.open("w+").write(content)
|
||||
|
||||
@rank_zero_only
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
try:
|
||||
npars = pl_module.model.model.num_parameters()
|
||||
except AttributeError:
|
||||
npars = pl_module.model.num_parameters()
|
||||
|
||||
n_trainable_pars = count_trainable_parameters(pl_module)
|
||||
# mp stands for million parameters
|
||||
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
return self._write_logs(trainer, pl_module, "test")
|
||||
|
||||
@rank_zero_only
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
# Uncommenting this will save val generations
|
||||
# return self._write_logs(trainer, pl_module, "valid")
|
||||
|
||||
|
||||
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
|
||||
"""Saves the best model by validation ROUGE2 score."""
|
||||
if metric == "rouge2":
|
||||
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
||||
elif metric == "bleu":
|
||||
exp = "{val_avg_bleu:.4f}-{step_count}"
|
||||
elif metric == "loss":
|
||||
exp = "{val_avg_loss:.4f}-{step_count}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=output_dir,
|
||||
filename=exp,
|
||||
monitor=f"val_{metric}",
|
||||
mode="min" if "loss" in metric else "max",
|
||||
save_top_k=save_top_k,
|
||||
)
|
||||
return checkpoint_callback
|
||||
|
||||
|
||||
def get_early_stopping_callback(metric, patience):
|
||||
return EarlyStopping(
|
||||
monitor=f"val_{metric}", # does this need avg?
|
||||
mode="min" if "loss" in metric else "max",
|
||||
patience=patience,
|
||||
verbose=True,
|
||||
)
|
||||
@@ -1,4 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import fire
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
from transformers.utils.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def remove_prefix(text: str, prefix: str):
|
||||
if text.startswith(prefix):
|
||||
return text[len(prefix) :]
|
||||
return text # or whatever
|
||||
|
||||
|
||||
def sanitize(sd):
|
||||
return {remove_prefix(k, "model."): v for k, v in sd.items()}
|
||||
|
||||
|
||||
def average_state_dicts(state_dicts: List[Dict[str, torch.Tensor]]):
|
||||
new_sd = {}
|
||||
for k in state_dicts[0].keys():
|
||||
tensors = [sd[k] for sd in state_dicts]
|
||||
new_t = sum(tensors) / len(tensors)
|
||||
assert isinstance(new_t, torch.Tensor)
|
||||
new_sd[k] = new_t
|
||||
return new_sd
|
||||
|
||||
|
||||
def convert_pl_to_hf(pl_ckpt_path: str, hf_src_model_dir: str, save_path: str) -> None:
|
||||
"""Cleanup a pytorch-lightning .ckpt file or experiment dir and save a huggingface model with that state dict.
|
||||
Silently allows extra pl keys (like teacher.) Puts all ckpt models into CPU RAM at once!
|
||||
|
||||
Args:
|
||||
pl_ckpt_path (:obj:`str`): Path to a .ckpt file saved by pytorch_lightning or dir containing ckpt files.
|
||||
If a directory is passed, all .ckpt files inside it will be averaged!
|
||||
hf_src_model_dir (:obj:`str`): Path to a directory containing a correctly shaped checkpoint
|
||||
save_path (:obj:`str`): Directory to save the new model
|
||||
|
||||
"""
|
||||
hf_model = AutoModelForSeq2SeqLM.from_pretrained(hf_src_model_dir)
|
||||
if os.path.isfile(pl_ckpt_path):
|
||||
ckpt_files = [pl_ckpt_path]
|
||||
else:
|
||||
assert os.path.isdir(pl_ckpt_path)
|
||||
ckpt_files = list(Path(pl_ckpt_path).glob("*.ckpt"))
|
||||
assert ckpt_files, f"could not find any ckpt files inside the {pl_ckpt_path} directory"
|
||||
|
||||
if len(ckpt_files) > 1:
|
||||
logger.info(f"averaging the weights of {ckpt_files}")
|
||||
|
||||
state_dicts = [sanitize(torch.load(x, map_location="cpu")["state_dict"]) for x in ckpt_files]
|
||||
state_dict = average_state_dicts(state_dicts)
|
||||
|
||||
missing, unexpected = hf_model.load_state_dict(state_dict, strict=False)
|
||||
assert not missing, f"missing keys: {missing}"
|
||||
hf_model.save_pretrained(save_path)
|
||||
try:
|
||||
tok = AutoTokenizer.from_pretrained(hf_src_model_dir)
|
||||
tok.save_pretrained(save_path)
|
||||
except Exception:
|
||||
pass
|
||||
# dont copy tokenizer if cant
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(convert_pl_to_hf)
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
export WANDB_PROJECT=dmar
|
||||
# export MAX_LEN=128
|
||||
python distillation.py \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--fp16 \
|
||||
--val_check_interval 0.25 \
|
||||
--teacher Helsinki-NLP/opus-mt-en-ro \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--student_decoder_layers 3 --student_encoder_layers 6 \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--model_name_or_path IGNORED \
|
||||
--alpha_hid=3. \
|
||||
--train_batch_size=$BS --eval_batch_size=$BS \
|
||||
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
|
||||
--warmup_steps 500 --logger_name wandb \
|
||||
--fp16_opt_level O1 --task translation --normalize_hidden --num_sanity_val_steps=0 \
|
||||
"$@"
|
||||
@@ -1,17 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
export WANDB_PROJECT=dmar
|
||||
python distillation.py \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--fp16 --no_teacher \
|
||||
--val_check_interval 0.25 \
|
||||
--data_dir $ENRO_DIR \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--train_batch_size=$BS --eval_batch_size=$BS \
|
||||
--tokenizer_name $m --model_name_or_path $m \
|
||||
--warmup_steps 500 --sortish_sampler --logger_name wandb \
|
||||
--gpus 1 --fp16_opt_level=O1 --task translation --num_sanity_val_steps=0 \
|
||||
"$@"
|
||||
@@ -1,310 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from finetune import SummarizationModule, TranslationModule
|
||||
from finetune import main as ft_main
|
||||
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
|
||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
from utils import calculate_bleu, check_output_dir, freeze_params, label_smoothed_nll_loss, use_task_specific_params
|
||||
|
||||
|
||||
# need the parent dir module
|
||||
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||
from lightning_base import generic_train # noqa
|
||||
|
||||
|
||||
class SummarizationDistiller(SummarizationModule):
|
||||
"""Supports T5, Bart, Pegasus and other models that inherit from Bart."""
|
||||
|
||||
loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||
|
||||
def __init__(self, hparams):
|
||||
assert Path(hparams.data_dir).exists()
|
||||
self.output_dir = Path(hparams.output_dir)
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
|
||||
save_dir = self.output_dir.joinpath("student")
|
||||
|
||||
hparams.model_name_or_path = str(save_dir) # Tell lightning we are training the student
|
||||
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
|
||||
use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default
|
||||
if hparams.student is not None:
|
||||
student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student)
|
||||
use_task_specific_params(student, hparams.task)
|
||||
e_layer_ids, d_layer_ids = None, None
|
||||
else:
|
||||
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
|
||||
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
|
||||
)
|
||||
|
||||
if hparams.length_penalty != -1:
|
||||
student.config.length_penalty = hparams.length_penalty
|
||||
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
|
||||
super().__init__(hparams, model=student, config=student.config)
|
||||
assert (
|
||||
student.config.model_type == teacher.config.model_type
|
||||
), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}"
|
||||
|
||||
if student.config.model_type == "t5":
|
||||
student_encoder_layers = len(student.get_encoder().block)
|
||||
student_decoder_layers = len(student.get_decoder().block)
|
||||
teacher_encoder_layers = len(teacher.get_encoder().block)
|
||||
teacher_decoder_layers = len(teacher.get_decoder().block)
|
||||
else:
|
||||
student_encoder_layers = student.config.encoder_layers
|
||||
student_decoder_layers = student.config.decoder_layers
|
||||
teacher_encoder_layers = teacher.config.encoder_layers
|
||||
teacher_decoder_layers = teacher.config.decoder_layers
|
||||
|
||||
self.different_base_models = not (hparams.student is None or hparams.teacher == hparams.student)
|
||||
self.do_calc_hidden_loss = (not self.different_base_models) and hparams.alpha_hid > 0
|
||||
self.different_encoder = self.different_base_models or (student_encoder_layers != teacher_encoder_layers)
|
||||
# self.different_encoder determines whether we need to run the teacher encoder
|
||||
self.teacher = teacher
|
||||
freeze_params(self.teacher)
|
||||
|
||||
if not self.different_encoder: # To save RAM, delete teacher encoder and freeze student encoder.
|
||||
try:
|
||||
del self.teacher.model.encoder
|
||||
except AttributeError: # T5
|
||||
del self.teacher.encoder
|
||||
|
||||
if e_layer_ids is None:
|
||||
e_layer_ids = list(range(student_encoder_layers))
|
||||
if d_layer_ids is None:
|
||||
d_layer_ids = list(range(student_decoder_layers))
|
||||
|
||||
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
|
||||
|
||||
if self.do_calc_hidden_loss: # Intermediate supervision: Decide which layers to supervise
|
||||
if hparams.supervise_forward:
|
||||
self.e_matches = get_layers_to_supervise(
|
||||
n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers
|
||||
)
|
||||
self.d_matches = get_layers_to_supervise(
|
||||
n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers
|
||||
)
|
||||
else: # student layer should emulate hidden states of the teacher layer it was copied from
|
||||
self.e_matches = self.e_layer_ids
|
||||
self.d_matches = self.d_layer_ids
|
||||
else:
|
||||
self.e_matches = None
|
||||
self.d_matches = None
|
||||
|
||||
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
self.temperature = 2.0
|
||||
self.alpha_mlm = hparams.alpha_mlm
|
||||
self.alpha_ce = hparams.alpha_ce
|
||||
self.alpha_hid = hparams.alpha_hid
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def calc_ce_loss(self, mask, s_logits, t_logits):
|
||||
"""Copy pasted from distillbert (transformers/examples/distillation/)"""
|
||||
# mask has False at padding_idx
|
||||
sel_mask = mask[:, :, None].expand_as(s_logits)
|
||||
vocab_size = s_logits.size(-1)
|
||||
s_logits_slct = torch.masked_select(s_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
t_logits_slct = torch.masked_select(t_logits, sel_mask) # (bs * seq_length * voc_size) modulo the 1s in mask
|
||||
s_logits_slct = s_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
t_logits_slct = t_logits_slct.view(-1, vocab_size) # (bs * seq_length, voc_size) modulo the 1s in mask
|
||||
assert t_logits_slct.size() == s_logits_slct.size()
|
||||
loss_ce = (
|
||||
self.ce_loss_fct(
|
||||
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||
)
|
||||
* (self.temperature) ** 2
|
||||
)
|
||||
return loss_ce
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
SummarizationModule.add_model_specific_args(parser, root_dir)
|
||||
add_distill_args(parser)
|
||||
return parser
|
||||
|
||||
def _step(self, batch: dict) -> tuple:
|
||||
"""Compute the loss for a batch"""
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
|
||||
if isinstance(self.model, T5ForConditionalGeneration):
|
||||
decoder_input_ids = self.model._shift_right(labels)
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(labels, pad_token_id)
|
||||
|
||||
# noinspection PyCallingNonCallable
|
||||
student_outputs = self(
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
output_hidden_states=self.do_calc_hidden_loss,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
)
|
||||
lm_logits = student_outputs["logits"]
|
||||
|
||||
# Same cross entropy vs. label smoothing logic as finetune.py
|
||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||
if self.hparams.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
|
||||
else:
|
||||
lprobs = F.log_softmax(lm_logits, dim=-1)
|
||||
student_lm_loss, _ = label_smoothed_nll_loss(
|
||||
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||
)
|
||||
|
||||
def zero_tensor():
|
||||
return torch.tensor(0.0).type_as(student_lm_loss)
|
||||
|
||||
teacher_enc_outputs = student_outputs[
|
||||
"encoder_last_hidden_state"
|
||||
] # use this unless self.different_base_models
|
||||
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
|
||||
if self.different_encoder: # compute encoder hidden state loss
|
||||
all_teacher_encoder_outputs = self.teacher.get_encoder()(
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
output_hidden_states=self.do_calc_hidden_loss,
|
||||
)
|
||||
if self.different_base_models:
|
||||
teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
|
||||
elif self.do_calc_hidden_loss:
|
||||
hid_loss_enc = self.calc_hidden_loss(
|
||||
src_mask,
|
||||
student_outputs["encoder_hidden_states"],
|
||||
all_teacher_encoder_outputs["hidden_states"],
|
||||
self.e_matches,
|
||||
normalize_hidden=self.hparams.normalize_hidden,
|
||||
)
|
||||
|
||||
teacher_outputs = self.teacher(
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
encoder_outputs=(teacher_enc_outputs,),
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
output_hidden_states=self.do_calc_hidden_loss,
|
||||
use_cache=False, # since we are not passing labels, never let this default to True
|
||||
)
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"])
|
||||
if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states
|
||||
hid_loss_dec = self.calc_hidden_loss(
|
||||
dec_mask,
|
||||
student_outputs["decoder_hidden_states"],
|
||||
teacher_outputs["decoder_hidden_states"],
|
||||
self.d_matches,
|
||||
normalize_hidden=self.hparams.normalize_hidden,
|
||||
)
|
||||
|
||||
blended_loss = (
|
||||
self.alpha_ce * loss_ce
|
||||
+ self.alpha_mlm * student_lm_loss
|
||||
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
|
||||
)
|
||||
return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
|
||||
|
||||
@staticmethod
|
||||
def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
|
||||
"""MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT."""
|
||||
msg = "expected list or tuple for hidden_states, got tensor of shape: "
|
||||
assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
|
||||
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
|
||||
mask = attention_mask.to(hidden_states[0])
|
||||
valid_count = mask.sum() * hidden_states[0].size(-1)
|
||||
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
|
||||
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
||||
assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
|
||||
if normalize_hidden:
|
||||
student_states = F.layer_norm(student_states, student_states.shape[1:])
|
||||
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
|
||||
mse = F.mse_loss(student_states, teacher_states, reduction="none")
|
||||
masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
|
||||
return masked_mse
|
||||
|
||||
|
||||
def add_distill_args(parser):
|
||||
# NOTE: if --student argument was specified and the teacher and student base models
|
||||
# are different, the models still have to have the same tokenizer, specified by
|
||||
# --tokenizer_name. So, for example, you can distill from t5_large to t5_small but not
|
||||
# from bart to t5. This s because if the tokenizers are different, the output space
|
||||
# for the two models is also different and their logits are not comparable.
|
||||
parser.add_argument("--teacher", type=str)
|
||||
parser.add_argument("--alpha_ce", default=0.8, type=float)
|
||||
parser.add_argument("--alpha_mlm", default=0.2, type=float)
|
||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||
parser.add_argument("--student", type=str, required=False)
|
||||
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||
parser.add_argument("--length_penalty", type=float, default=-1)
|
||||
parser.add_argument("--supervise_forward", action="store_true", default=False)
|
||||
parser.add_argument("--normalize_hidden", action="store_true", default=False)
|
||||
|
||||
|
||||
class TranslationDistiller(SummarizationDistiller):
|
||||
"""Supports T5, mBART, Marian, other models that inherit from Bart."""
|
||||
|
||||
mode = "translation"
|
||||
metric_names = ["bleu"]
|
||||
default_val_metric = "bleu"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
assert hparams.src_lang is not None
|
||||
assert hparams.tgt_lang is not None
|
||||
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu(preds, target)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
TranslationModule.add_model_specific_args(parser, root_dir)
|
||||
add_distill_args(parser)
|
||||
return parser
|
||||
|
||||
|
||||
def create_module(args):
|
||||
if args.no_teacher:
|
||||
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
|
||||
else: # DISTILL WITH TEACHER
|
||||
module_cls = TranslationDistiller if "translation" in args.task else SummarizationDistiller
|
||||
args.setup_cls: str = module_cls.__name__
|
||||
print(f"using module {args.setup_cls}")
|
||||
model = module_cls(args)
|
||||
return model
|
||||
|
||||
|
||||
def distill_main(args):
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
check_output_dir(args, expected_items=3)
|
||||
|
||||
model = create_module(args)
|
||||
return ft_main(args, model=model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
|
||||
distill_main(args)
|
||||
@@ -1,4 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
export WANDB_PROJECT=dmar
|
||||
export MAX_LEN=128
|
||||
export m=sshleifer/student_marian_en_ro_6_1
|
||||
python finetune.py \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--fp16 \
|
||||
--data_dir wmt_en_ro \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--train_batch_size=48 --eval_batch_size=64 \
|
||||
--tokenizer_name $m --model_name_or_path $m --num_train_epochs=1 \
|
||||
--warmup_steps 500 --logger_name wandb --gpus 1 \
|
||||
--fp16_opt_level=O1 --task translation \
|
||||
"$@"
|
||||
@@ -1,442 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
from utils import (
|
||||
ROUGE_KEYS,
|
||||
LegacySeq2SeqDataset,
|
||||
Seq2SeqDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
check_output_dir,
|
||||
flatten_list,
|
||||
freeze_embeds,
|
||||
freeze_params,
|
||||
get_git_info,
|
||||
label_smoothed_nll_loss,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
|
||||
# need the parent dir module
|
||||
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SummarizationModule(BaseTransformer):
|
||||
mode = "summarization"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ROUGE_KEYS
|
||||
default_val_metric = "rouge2"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
if hparams.sortish_sampler and hparams.gpus > 1:
|
||||
hparams.replace_sampler_ddp = False
|
||||
elif hparams.max_tokens_per_batch is not None:
|
||||
if hparams.gpus > 1:
|
||||
raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
|
||||
if hparams.sortish_sampler:
|
||||
raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")
|
||||
|
||||
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||
use_task_specific_params(self.model, "summarization")
|
||||
save_git_info(self.hparams.output_dir)
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
|
||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||
pickle_save(self.hparams, self.hparams_save_path)
|
||||
self.step_count = 0
|
||||
self.metrics = defaultdict(list)
|
||||
self.model_type = self.config.model_type
|
||||
self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
max_source_length=self.hparams.max_source_length,
|
||||
prefix=self.model.config.prefix or "",
|
||||
)
|
||||
n_observations_per_split = {
|
||||
"train": self.hparams.n_train,
|
||||
"val": self.hparams.n_val,
|
||||
"test": self.hparams.n_test,
|
||||
}
|
||||
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
|
||||
|
||||
self.target_lens = {
|
||||
"train": self.hparams.max_target_length,
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
if self.hparams.freeze_embeds:
|
||||
freeze_embeds(self.model)
|
||||
if self.hparams.freeze_encoder:
|
||||
freeze_params(self.model.get_encoder())
|
||||
assert_all_frozen(self.model.get_encoder())
|
||||
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.decoder_start_token_id = None # default to config
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
self.model.config.decoder_start_token_id = self.decoder_start_token_id
|
||||
self.dataset_class = (
|
||||
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
|
||||
)
|
||||
self.already_saved_batch = False
|
||||
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
|
||||
if self.hparams.eval_max_gen_length is not None:
|
||||
self.eval_max_length = self.hparams.eval_max_gen_length
|
||||
else:
|
||||
self.eval_max_length = self.model.config.max_length
|
||||
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
|
||||
|
||||
def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
|
||||
"""A debugging utility"""
|
||||
readable_batch = {
|
||||
k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
|
||||
}
|
||||
save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
|
||||
save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")
|
||||
|
||||
self.already_saved_batch = True
|
||||
return readable_batch
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
|
||||
def ids_to_clean_text(self, generated_ids: List[int]):
|
||||
gen_text = self.tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
return lmap(str.strip, gen_text)
|
||||
|
||||
def _step(self, batch: dict) -> Tuple:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
|
||||
tgt_ids = batch["labels"]
|
||||
if isinstance(self.model, T5ForConditionalGeneration):
|
||||
decoder_input_ids = self.model._shift_right(tgt_ids)
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
|
||||
if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero
|
||||
batch["decoder_input_ids"] = decoder_input_ids
|
||||
self.save_readable_batch(batch)
|
||||
|
||||
outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
|
||||
lm_logits = outputs["logits"]
|
||||
if self.hparams.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
|
||||
assert lm_logits.shape[-1] == self.vocab_size
|
||||
loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
||||
else:
|
||||
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
||||
loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||
)
|
||||
return (loss,)
|
||||
|
||||
@property
|
||||
def pad(self) -> int:
|
||||
return self.tokenizer.pad_token_id
|
||||
|
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
loss_tensors = self._step(batch)
|
||||
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
# tokens per batch
|
||||
logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
|
||||
logs["bs"] = batch["input_ids"].shape[0]
|
||||
logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
|
||||
logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
|
||||
# TODO(SS): make a wandb summary metric for this
|
||||
return {"loss": loss_tensors[0], "log": logs}
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
return self._generative_step(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
generative_metrics = {
|
||||
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
|
||||
}
|
||||
metric_val = (
|
||||
generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
|
||||
)
|
||||
metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
|
||||
generative_metrics.update({k: v.item() for k, v in losses.items()})
|
||||
losses.update(generative_metrics)
|
||||
all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
all_metrics["step_count"] = self.step_count
|
||||
self.metrics[prefix].append(all_metrics) # callback writes this to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {
|
||||
"log": all_metrics,
|
||||
"preds": preds,
|
||||
f"{prefix}_loss": loss,
|
||||
f"{prefix}_{self.val_metric}": metric_tensor,
|
||||
}
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_rouge(preds, target)
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
t0 = time.time()
|
||||
|
||||
# parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
|
||||
generated_ids = self.model.generate(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
use_cache=True,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
num_beams=self.eval_beams,
|
||||
max_length=self.eval_max_length,
|
||||
)
|
||||
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
|
||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||
target: List[str] = self.ids_to_clean_text(batch["labels"])
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
rouge: Dict = self.calc_generative_metrics(preds, target)
|
||||
summ_len = np.mean(lmap(len, generated_ids))
|
||||
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)
|
||||
return base_metrics
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_epoch_end(outputs, prefix="test")
|
||||
|
||||
def get_dataset(self, type_path) -> Seq2SeqDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = self.dataset_class(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
max_target_length=max_target_length,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
|
||||
if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
|
||||
sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
|
||||
batch_sampler = dataset.make_dynamic_sampler(
|
||||
self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=dataset.collate_fn,
|
||||
# shuffle=False,
|
||||
num_workers=self.num_workers,
|
||||
# batch_size=None,
|
||||
)
|
||||
else:
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=shuffle,
|
||||
num_workers=self.num_workers,
|
||||
sampler=None,
|
||||
)
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
return dataloader
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||
add_generic_args(parser, root_dir)
|
||||
parser.add_argument(
|
||||
"--max_source_length",
|
||||
default=1024,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
default=56,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=142, # these defaults are optimized for CNNDM. For xsum, see README.md.
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=142,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--freeze_encoder", action="store_true")
|
||||
parser.add_argument("--freeze_embeds", action="store_true")
|
||||
parser.add_argument("--sortish_sampler", action="store_true", default=False)
|
||||
parser.add_argument("--overwrite_output_dir", action="store_true", default=False)
|
||||
parser.add_argument("--max_tokens_per_batch", type=int, default=None)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument(
|
||||
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
|
||||
)
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
|
||||
parser.add_argument("--src_lang", type=str, default="", required=False)
|
||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||
parser.add_argument("--eval_beams", type=int, default=None, required=False)
|
||||
parser.add_argument(
|
||||
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
|
||||
)
|
||||
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
|
||||
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
|
||||
parser.add_argument(
|
||||
"--early_stopping_patience",
|
||||
type=int,
|
||||
default=-1,
|
||||
required=False,
|
||||
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
class TranslationModule(SummarizationModule):
|
||||
mode = "translation"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["bleu"]
|
||||
default_val_metric = "bleu"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu(preds, target)
|
||||
|
||||
|
||||
def main(args, model=None) -> SummarizationModule:
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
check_output_dir(args, expected_items=3)
|
||||
|
||||
if model is None:
|
||||
if "summarization" in args.task:
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
else:
|
||||
model: SummarizationModule = TranslationModule(args)
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger_name == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
|
||||
if args.early_stopping_patience >= 0:
|
||||
es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||
else:
|
||||
es_callback = False
|
||||
|
||||
lower_is_better = args.val_metric == "loss"
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(
|
||||
args.output_dir, model.val_metric, args.save_top_k, lower_is_better
|
||||
),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
if not args.do_predict:
|
||||
return model
|
||||
|
||||
model.hparams.test_checkpoint = ""
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
|
||||
if checkpoints:
|
||||
model.hparams.test_checkpoint = checkpoints[-1]
|
||||
trainer.resume_from_checkpoint = checkpoints[-1]
|
||||
trainer.logger.log_hyperparams(model.hparams)
|
||||
|
||||
# test() without a model tests using the best checkpoint automatically
|
||||
trainer.test()
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
23
examples/seq2seq/finetune.sh
Executable file → Normal file
23
examples/seq2seq/finetune.sh
Executable file → Normal file
@@ -1,11 +1,24 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./finetune.sh --help to see all the possible options
|
||||
python finetune.py \
|
||||
python finetune_trainer.py \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--gpus 1 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate \
|
||||
--n_val 1000 \
|
||||
--val_check_interval 0.1 \
|
||||
"$@"
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
# Script for verifying that run_bart_sum can be invoked from its directory
|
||||
|
||||
# Get tiny dataset with cnn_dm format (4 examples for train, val, test)
|
||||
wget https://cdn-datasets.huggingface.co/summarization/cnn_tiny.tgz
|
||||
tar -xzvf cnn_tiny.tgz
|
||||
rm cnn_tiny.tgz
|
||||
|
||||
export OUTPUT_DIR_NAME=bart_utest_output
|
||||
export CURRENT_DIR=${PWD}
|
||||
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
|
||||
|
||||
# Make output directory if it doesn't exist
|
||||
mkdir -p $OUTPUT_DIR
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py and testing_utils.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
python finetune.py \
|
||||
--data_dir=cnn_tiny/ \
|
||||
--model_name_or_path=sshleifer/bart-tiny-random \
|
||||
--learning_rate=3e-5 \
|
||||
--train_batch_size=2 \
|
||||
--eval_batch_size=2 \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--num_train_epochs=1 \
|
||||
--gpus=0 \
|
||||
--do_train "$@"
|
||||
|
||||
rm -rf cnn_tiny
|
||||
rm -rf $OUTPUT_DIR
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# From appendix C of paper https://arxiv.org/abs/1912.08777
|
||||
# Set --gradient_accumulation_steps so that effective batch size is 256 (2*128, 4*64, 8*32, 16*16)
|
||||
python finetune.py \
|
||||
--learning_rate=1e-4 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val 1000 \
|
||||
--val_check_interval 0.25 \
|
||||
--max_source_length 512 --max_target_length 56 \
|
||||
--freeze_embeds --label_smoothing 0.1 --adafactor --task summarization_xsum \
|
||||
"$@"
|
||||
@@ -1,14 +0,0 @@
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
python finetune.py \
|
||||
--data_dir=$CNN_DIR \
|
||||
--learning_rate=3e-5 \
|
||||
--train_batch_size=$BS \
|
||||
--eval_batch_size=$BS \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--max_source_length=512 \
|
||||
--max_target_length=56 \
|
||||
--val_check_interval=0.1 --n_val=200 \
|
||||
--do_train --do_predict \
|
||||
"$@"
|
||||
26
examples/seq2seq/finetune_tpu.sh
Normal file
26
examples/seq2seq/finetune_tpu.sh
Normal file
@@ -0,0 +1,26 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
export TPU_NUM_CORES=8
|
||||
|
||||
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./finetune_tpu.sh --help to see all the possible options
|
||||
python xla_spawn.py --num_cores $TPU_NUM_CORES \
|
||||
finetune_trainer.py \
|
||||
--learning_rate=3e-5 \
|
||||
--do_train --do_eval \
|
||||
--evaluation_strategy steps \
|
||||
--prediction_loss_only \
|
||||
--n_val 1000 \
|
||||
"$@"
|
||||
@@ -1,4 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import fire
|
||||
from torch import nn
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList, layers_to_copy: List[int]) -> None:
|
||||
layers_to_copy = nn.ModuleList([src_layers[i] for i in layers_to_copy])
|
||||
assert len(dest_layers) == len(layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}"
|
||||
dest_layers.load_state_dict(layers_to_copy.state_dict())
|
||||
|
||||
|
||||
LAYERS_TO_COPY = {
|
||||
# maps num layers in teacher -> num_layers in student -> which teacher layers to copy.
|
||||
# 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
|
||||
12: {
|
||||
1: [0], # This says that if the teacher has 12 layers and the student has 1, copy layer 0 of the teacher
|
||||
2: [0, 6],
|
||||
3: [0, 6, 11],
|
||||
4: [0, 4, 8, 11],
|
||||
6: [0, 2, 4, 7, 9, 11],
|
||||
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
||||
12: list(range(12)),
|
||||
},
|
||||
16: { # maps num layers in student -> which teacher layers to copy
|
||||
1: [0],
|
||||
2: [0, 15],
|
||||
3: [0, 8, 15],
|
||||
4: [0, 5, 10, 15],
|
||||
6: [0, 3, 6, 9, 12, 15],
|
||||
8: [0, 2, 4, 6, 8, 10, 12, 15],
|
||||
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
|
||||
12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15],
|
||||
16: list(range(16)),
|
||||
},
|
||||
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
|
||||
}
|
||||
LAYERS_TO_SUPERVISE = {
|
||||
# maps num layers in student -> which teacher layers to copy.
|
||||
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
|
||||
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
|
||||
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
|
||||
}
|
||||
|
||||
|
||||
def pick_layers_to_copy(n_student, n_teacher):
|
||||
try:
|
||||
val = LAYERS_TO_COPY[n_teacher][n_student]
|
||||
return val
|
||||
except KeyError:
|
||||
if n_student != n_teacher:
|
||||
warnings.warn(
|
||||
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
|
||||
)
|
||||
return list(range(n_student))
|
||||
|
||||
|
||||
def get_layers_to_supervise(n_student, n_teacher) -> List[int]:
|
||||
"""Used or the --supervise_forward kwarg"""
|
||||
if n_student > n_teacher:
|
||||
raise ValueError(f"Cannot perform intermediate supervision for student {n_student} > teacher {n_teacher}")
|
||||
elif n_teacher == n_student:
|
||||
return list(range(n_teacher))
|
||||
elif n_student == 1:
|
||||
return [n_teacher - 1]
|
||||
else:
|
||||
return LAYERS_TO_SUPERVISE[n_teacher][n_student]
|
||||
|
||||
|
||||
def create_student_by_copying_alternating_layers(
|
||||
teacher: Union[str, PreTrainedModel],
|
||||
save_path: Union[str, Path] = "student",
|
||||
e: Union[int, None] = None,
|
||||
d: Union[int, None] = None,
|
||||
copy_first_teacher_layers=False,
|
||||
e_layers_to_copy=None,
|
||||
d_layers_to_copy=None,
|
||||
**extra_config_kwargs
|
||||
) -> Tuple[PreTrainedModel, List[int], List[int]]:
|
||||
"""Make a student by copying alternating layers from a teacher, save it to save_path.
|
||||
Args:
|
||||
teacher: str or PreTrainedModel if str, this will call AutoModelForSeq2SeqLM.from_pretrained(teacher) before
|
||||
copying layers
|
||||
save_path: where to save the student, defaults to student directory.
|
||||
e: how many Encoder layers should the student have, default is fully copy of teacher
|
||||
d: how many Decoder layers should the student have, default is fully copy of teacher
|
||||
copy_first_teacher_layers: [bool] dont copy alternating layers, just the first e/d.
|
||||
**extra_config_kwargs: extra kwargs to pass to the student, by default the teacher config is used.
|
||||
|
||||
Returns:
|
||||
student: new, smaller model. (Also saves it to save_path)
|
||||
e_layers_to_copy: list of which teacher encoder layers were used
|
||||
d_layers_to_copy: list of which teacher decoder layers were used
|
||||
"""
|
||||
_msg = "encoder_layers and decoder_layers cannot be both None-- you would just have an identical teacher."
|
||||
assert (e is not None) or (d is not None), _msg
|
||||
if isinstance(teacher, str):
|
||||
AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path) # purely for convenience
|
||||
teacher = AutoModelForSeq2SeqLM.from_pretrained(teacher).eval()
|
||||
else:
|
||||
|
||||
assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}"
|
||||
init_kwargs = teacher.config.to_diff_dict()
|
||||
|
||||
try:
|
||||
teacher_e, teacher_d = teacher.config.encoder_layers, teacher.config.decoder_layers
|
||||
if e is None:
|
||||
e = teacher_e
|
||||
if d is None:
|
||||
d = teacher_d
|
||||
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
|
||||
except AttributeError: # T5
|
||||
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
|
||||
if e is None:
|
||||
e = teacher_e
|
||||
if d is None:
|
||||
d = teacher_d
|
||||
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
|
||||
|
||||
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
|
||||
init_kwargs.update(extra_config_kwargs)
|
||||
|
||||
# Copy weights
|
||||
student_cfg = teacher.config_class(**init_kwargs)
|
||||
student = AutoModelForSeq2SeqLM.from_config(student_cfg)
|
||||
# Start by copying the full teacher state dict this will copy the first N teacher layers to the student.
|
||||
info = student.load_state_dict(teacher.state_dict(), strict=False)
|
||||
assert info.missing_keys == [], info.missing_keys # every student key should have a teacher keys.
|
||||
|
||||
if copy_first_teacher_layers: # Our copying is done. We just log and save
|
||||
e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d))
|
||||
logger.info(
|
||||
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
|
||||
)
|
||||
student.save_pretrained(save_path)
|
||||
return student, e_layers_to_copy, d_layers_to_copy
|
||||
|
||||
# Decide which layers of the teacher to copy. Not exactly alternating -- we try to keep first and last layer.
|
||||
if e_layers_to_copy is None:
|
||||
e_layers_to_copy: List[int] = pick_layers_to_copy(e, teacher_e)
|
||||
if d_layers_to_copy is None:
|
||||
d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
|
||||
|
||||
try:
|
||||
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
|
||||
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
|
||||
except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block
|
||||
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
||||
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
|
||||
logger.info(
|
||||
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
|
||||
)
|
||||
student.config.init_metadata = dict(
|
||||
teacher_type=teacher.config.model_type,
|
||||
copied_encoder_layers=e_layers_to_copy,
|
||||
copied_decoder_layers=d_layers_to_copy,
|
||||
)
|
||||
student.save_pretrained(save_path)
|
||||
# Save information about copying for easier reproducibility
|
||||
|
||||
return student, e_layers_to_copy, d_layers_to_copy
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(create_student_by_copying_alternating_layers)
|
||||
@@ -1,4 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -1,5 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fill examples with bitext up to max_tokens without breaking up examples.
|
||||
[['I went', 'yo fui'],
|
||||
['to the store', 'a la tienda']
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
### Saved Pseudo-Labels
|
||||
These are the generations of various large models on various large **training** sets. All in all they took about 200 GPU hours to produce.
|
||||
|
||||
### Available Pseudo-labels
|
||||
| Dataset | Model | Link | Rouge Scores | Notes
|
||||
|---------|-----------------------------|----------------------------------------------------------------------------------------|--------------------|-------------------------------------------------------------------------------------------------------------
|
||||
| XSUM | `facebook/bart-large-xsum` | [download](https://cdn-datasets.huggingface.co/pseudo/xsum/bart_xsum_pl.tgz) | 49.8/28.0/42.5 |
|
||||
| XSUM | `google/pegasus-xsum` | [download](https://cdn-datasets.huggingface.co/pseudo/xsum/pegasus_xsum.tgz) | 53.3/32.7/46.5 |
|
||||
| XSUM | `facebook/bart-large-xsum` | [download](https://cdn-datasets.huggingface.co/pseudo/xsum/xsum_pl2_bart.tgz) | | Bart pseudolabels filtered to those with Rouge2 > 10.0 w GT.
|
||||
| CNN/DM | `sshleifer/pegasus-cnn-ft-v2` | [download](https://cdn-datasets.huggingface.co/pseudo/cnn_dm/pegasus_cnn_cnn_pls.tgz) | 47.316/26.65/44.56 | do not worry about the fact that train.source is one line shorter.
|
||||
| CNN/DM | `facebook/bart-large-cnn` | [download](https://cdn-datasets.huggingface.co/pseudo/cnn_dm/cnn_bart_pl.tgz) | | 5K (2%) are missing, there should be 282173
|
||||
| CNN/DM | `google/pegasus-xsum` | [download](https://cdn-datasets.huggingface.co/pseudo/cnn_dm/pegasus_xsum_on_cnn.tgz) | 21.5/6.76/25 | extra labels for xsum distillation Used max_source_length=512, (and all other pegasus-xsum configuration).
|
||||
| EN-RO | `Helsinki-NLP/opus-mt-en-ro` | [download](https://cdn-datasets.huggingface.co/pseudo/wmt_en_ro/opus_mt_en_ro.tgz) | |
|
||||
| EN-RO | `facebook/mbart-large-en-ro` | [download](https://cdn-datasets.huggingface.co/pseudo/wmt_en_ro/mbart_large_en_ro.tgz) | |
|
||||
|
||||
|
||||
(EN_RO = WMT 2016 English-Romanian).
|
||||
|
||||
Example Download Command:
|
||||
```bash
|
||||
curl -S https://cdn-datasets.huggingface.co/pseudo/xsum/bart_xsum_pl.tgz | tar -xvz -C .
|
||||
```
|
||||
### Generating New Pseudolabels
|
||||
Here is the command I used to generate the pseudolabels in the second row of the table, after downloading XSUM from [here](https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz).
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
|
||||
--model_name google/pegasus-xsum \
|
||||
--save_dir pegasus_xsum \
|
||||
--data_dir xsum \
|
||||
--bs 8 --sync_timeout 60000 \
|
||||
--max_source_length 512 \
|
||||
--type_path train
|
||||
```
|
||||
|
||||
+ These commands takes a while to run. For example, `pegasus_cnn_cnn_pls.tgz` took 8 hours on 8 GPUs.
|
||||
+ Pegasus does not work in fp16 :(, Bart, mBART and Marian do.
|
||||
+ Even if you have 1 GPU, `run_distributed_eval.py` is 10-20% faster than `run_eval.py` because it uses `SortishSampler` to minimize padding computation.
|
||||
|
||||
### Contributions
|
||||
Feel free to contribute your own pseudolabels via PR. Add a row to this table with a new google drive link (or other command line downloadable link).
|
||||
|
||||
|
||||
20
examples/seq2seq/requirements.txt
Normal file
20
examples/seq2seq/requirements.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
tensorboard
|
||||
scikit-learn
|
||||
seqeval
|
||||
psutil
|
||||
sacrebleu
|
||||
rouge-score
|
||||
tensorflow_datasets
|
||||
matplotlib
|
||||
git-python==1.0.3
|
||||
faiss-cpu
|
||||
streamlit
|
||||
elasticsearch
|
||||
nltk
|
||||
pandas
|
||||
datasets >= 1.1.3
|
||||
fire
|
||||
pytest
|
||||
conllu
|
||||
sentencepiece != 0.1.92
|
||||
protobuf
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fire
|
||||
|
||||
from utils import calculate_rouge, save_json
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fire
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
@@ -1,4 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import fire
|
||||
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import timeout_decorator
|
||||
import torch
|
||||
|
||||
from distillation import SummarizationDistiller, distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
from transformers import MarianMTModel
|
||||
from transformers.file_utils import cached_path
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
|
||||
from utils import load_json
|
||||
|
||||
|
||||
MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"
|
||||
|
||||
|
||||
class TestMbartCc25Enro(TestCasePlus):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
data_cached = cached_path(
|
||||
"https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
|
||||
extract_compressed_file=True,
|
||||
)
|
||||
self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_model_download(self):
|
||||
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
|
||||
MarianMTModel.from_pretrained(MARIAN_MODEL)
|
||||
|
||||
# @timeout_decorator.timeout(1200)
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_train_mbart_cc25_enro_script(self):
|
||||
env_vars_to_replace = {
|
||||
"$MAX_LEN": 64,
|
||||
"$BS": 64,
|
||||
"$GAS": 1,
|
||||
"$ENRO_DIR": self.data_dir,
|
||||
"facebook/mbart-large-cc25": MARIAN_MODEL,
|
||||
# "val_check_interval=0.25": "val_check_interval=1.0",
|
||||
"--learning_rate=3e-5": "--learning_rate 3e-4",
|
||||
"--num_train_epochs 6": "--num_train_epochs 1",
|
||||
}
|
||||
|
||||
# Clean up bash script
|
||||
bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
|
||||
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
||||
for k, v in env_vars_to_replace.items():
|
||||
bash_script = bash_script.replace(k, str(v))
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
# bash_script = bash_script.replace("--fp16 ", "")
|
||||
args = f"""
|
||||
--output_dir {output_dir}
|
||||
--tokenizer_name Helsinki-NLP/opus-mt-en-ro
|
||||
--sortish_sampler
|
||||
--do_predict
|
||||
--gpus 1
|
||||
--freeze_encoder
|
||||
--n_train 40000
|
||||
--n_val 500
|
||||
--n_test 500
|
||||
--fp16_opt_level O1
|
||||
--num_sanity_val_steps 0
|
||||
--eval_beams 2
|
||||
""".split()
|
||||
# XXX: args.gpus > 1 : handle multi_gpu in the future
|
||||
|
||||
testargs = ["finetune.py"] + bash_script.split() + args
|
||||
with patch.object(sys, "argv", testargs):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
model = main(args)
|
||||
|
||||
# Check metrics
|
||||
metrics = load_json(model.metrics_save_path)
|
||||
first_step_stats = metrics["val"][0]
|
||||
last_step_stats = metrics["val"][-1]
|
||||
self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
|
||||
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
|
||||
self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
|
||||
# model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
|
||||
self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)
|
||||
|
||||
# test learning requirements:
|
||||
|
||||
# 1. BLEU improves over the course of training by more than 2 pts
|
||||
self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)
|
||||
|
||||
# 2. BLEU finishes above 17
|
||||
self.assertGreater(last_step_stats["val_avg_bleu"], 17)
|
||||
|
||||
# 3. test BLEU and val BLEU within ~1.1 pt.
|
||||
self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)
|
||||
|
||||
# check lightning ckpt can be loaded and has a reasonable statedict
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
|
||||
full_path = os.path.join(args.output_dir, ckpt_path)
|
||||
ckpt = torch.load(full_path, map_location="cpu")
|
||||
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
|
||||
assert expected_key in ckpt["state_dict"]
|
||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||
|
||||
# TODO: turn on args.do_predict when PL bug fixed.
|
||||
if args.do_predict:
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.txt" in contents
|
||||
# assert len(metrics["val"]) == desired_n_evals
|
||||
assert len(metrics["test"]) == 1
|
||||
|
||||
|
||||
class TestDistilMarianNoTeacher(TestCasePlus):
|
||||
@timeout_decorator.timeout(600)
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_opus_mt_distill_script(self):
|
||||
data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
|
||||
env_vars_to_replace = {
|
||||
"--fp16_opt_level=O1": "",
|
||||
"$MAX_LEN": 128,
|
||||
"$BS": 16,
|
||||
"$GAS": 1,
|
||||
"$ENRO_DIR": data_dir,
|
||||
"$m": "sshleifer/student_marian_en_ro_6_1",
|
||||
"val_check_interval=0.25": "val_check_interval=1.0",
|
||||
}
|
||||
|
||||
# Clean up bash script
|
||||
bash_script = (
|
||||
(self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
|
||||
)
|
||||
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
|
||||
bash_script = bash_script.replace("--fp16 ", " ")
|
||||
|
||||
for k, v in env_vars_to_replace.items():
|
||||
bash_script = bash_script.replace(k, str(v))
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
bash_script = bash_script.replace("--fp16", "")
|
||||
epochs = 6
|
||||
testargs = (
|
||||
["distillation.py"]
|
||||
+ bash_script.split()
|
||||
+ [
|
||||
f"--output_dir={output_dir}",
|
||||
"--gpus=1",
|
||||
"--learning_rate=1e-3",
|
||||
f"--num_train_epochs={epochs}",
|
||||
"--warmup_steps=10",
|
||||
"--val_check_interval=1.0",
|
||||
"--do_predict",
|
||||
]
|
||||
)
|
||||
with patch.object(sys, "argv", testargs):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
# assert args.gpus == gpus THIS BREAKS for multi_gpu
|
||||
|
||||
model = distill_main(args)
|
||||
|
||||
# Check metrics
|
||||
metrics = load_json(model.metrics_save_path)
|
||||
first_step_stats = metrics["val"][0]
|
||||
last_step_stats = metrics["val"][-1]
|
||||
assert len(metrics["val"]) >= (args.max_epochs / args.val_check_interval) # +1 accounts for val_sanity_check
|
||||
|
||||
assert last_step_stats["val_avg_gen_time"] >= 0.01
|
||||
|
||||
assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing
|
||||
assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved.
|
||||
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
|
||||
# check lightning ckpt can be loaded and has a reasonable statedict
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_path = [x for x in contents if x.endswith(".ckpt")][0]
|
||||
full_path = os.path.join(args.output_dir, ckpt_path)
|
||||
ckpt = torch.load(full_path, map_location="cpu")
|
||||
expected_key = "model.model.decoder.layers.0.encoder_attn_layer_norm.weight"
|
||||
assert expected_key in ckpt["state_dict"]
|
||||
assert ckpt["state_dict"]["model.model.decoder.layers.0.encoder_attn_layer_norm.weight"].dtype == torch.float32
|
||||
|
||||
# TODO: turn on args.do_predict when PL bug fixed.
|
||||
if args.do_predict:
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.txt" in contents
|
||||
# assert len(metrics["val"]) == desired_n_evals
|
||||
assert len(metrics["test"]) == 1
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
1
examples/seq2seq/test_data/test_data
Symbolic link
1
examples/seq2seq/test_data/test_data
Symbolic link
@@ -0,0 +1 @@
|
||||
seq2seq/test_data
|
||||
Binary file not shown.
Binary file not shown.
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -8,7 +22,6 @@ from torch.utils.data import DataLoader
|
||||
from pack_dataset import pack_data_dir
|
||||
from parameterized import parameterized
|
||||
from save_len_file import save_len_file
|
||||
from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, slow
|
||||
@@ -17,6 +30,24 @@ from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDat
|
||||
|
||||
BERT_BASE_CASED = "bert-base-cased"
|
||||
PEGASUS_XSUM = "google/pegasus-xsum"
|
||||
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||
|
||||
|
||||
def _dump_articles(path: Path, articles: list):
|
||||
content = "\n".join(articles)
|
||||
Path(path).open("w").writelines(content)
|
||||
|
||||
|
||||
def make_test_data_dir(tmp_dir):
|
||||
for split in ["train", "val", "test"]:
|
||||
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
|
||||
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class TestAll(TestCasePlus):
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
@@ -17,11 +31,11 @@ from transformers.trainer_utils import set_seed
|
||||
|
||||
from .finetune_trainer import Seq2SeqTrainingArguments, main
|
||||
from .seq2seq_trainer import Seq2SeqTrainer
|
||||
from .test_seq2seq_examples import MBART_TINY
|
||||
|
||||
|
||||
set_seed(42)
|
||||
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
|
||||
|
||||
class TestFinetuneTrainer(TestCasePlus):
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from make_student import create_student_by_copying_alternating_layers
|
||||
from transformers import AutoConfig
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, require_torch_non_multi_gpu_but_fix_me
|
||||
|
||||
|
||||
TINY_BART = "sshleifer/bart-tiny-random"
|
||||
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||
|
||||
|
||||
@require_torch
|
||||
class MakeStudentTester(unittest.TestCase):
|
||||
@cached_property
|
||||
def teacher_config(self):
|
||||
return AutoConfig.from_pretrained(TINY_BART)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_valid_t5(self):
|
||||
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1)
|
||||
self.assertEqual(student.config.num_hidden_layers, 1)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_asymmetric_t5(self):
|
||||
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_same_decoder_small_encoder(self):
|
||||
student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=None)
|
||||
self.assertEqual(student.config.encoder_layers, 1)
|
||||
self.assertEqual(student.config.decoder_layers, self.teacher_config.encoder_layers)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_small_enc_small_dec(self):
|
||||
student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=1)
|
||||
self.assertEqual(student.config.encoder_layers, 1)
|
||||
self.assertEqual(student.config.decoder_layers, 1)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_raises_assert(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=None, d=None)
|
||||
@@ -1,96 +1,32 @@
|
||||
import argparse
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import lightning_base
|
||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
from distillation import distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
from parameterized import parameterized
|
||||
from run_eval import generate_summaries_or_translations, run_generate
|
||||
from run_eval import run_generate
|
||||
from run_eval_search import run_search
|
||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||
from transformers.hf_api import HfApi
|
||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
|
||||
from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
|
||||
from transformers.testing_utils import CaptureStdout, TestCasePlus, slow
|
||||
from utils import ROUGE_KEYS
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"max_tokens_per_batch": None,
|
||||
"supervise_forward": True,
|
||||
"normalize_hidden": True,
|
||||
"label_smoothing": 0.2,
|
||||
"eval_max_gen_length": None,
|
||||
"eval_beams": 1,
|
||||
"val_metric": "loss",
|
||||
"save_top_k": 1,
|
||||
"adafactor": True,
|
||||
"early_stopping_patience": 2,
|
||||
"logger_name": "default",
|
||||
"length_penalty": 0.5,
|
||||
"cache_dir": "",
|
||||
"task": "summarization",
|
||||
"num_workers": 2,
|
||||
"alpha_hid": 0,
|
||||
"freeze_embeds": True,
|
||||
"enc_only": False,
|
||||
"tgt_suffix": "",
|
||||
"resume_from_checkpoint": None,
|
||||
"sortish_sampler": True,
|
||||
"student_decoder_layers": 1,
|
||||
"val_check_interval": 1.0,
|
||||
"output_dir": "",
|
||||
"fp16": False, # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
|
||||
"no_teacher": False,
|
||||
"fp16_opt_level": "O1",
|
||||
"gpus": 1 if CUDA_AVAILABLE else 0,
|
||||
"n_tpu_cores": 0,
|
||||
"max_grad_norm": 1.0,
|
||||
"do_train": True,
|
||||
"do_predict": True,
|
||||
"accumulate_grad_batches": 1,
|
||||
"server_ip": "",
|
||||
"server_port": "",
|
||||
"seed": 42,
|
||||
"model_name_or_path": "sshleifer/bart-tiny-random",
|
||||
"config_name": "",
|
||||
"tokenizer_name": "facebook/bart-large",
|
||||
"do_lower_case": False,
|
||||
"learning_rate": 0.3,
|
||||
"lr_scheduler": "linear",
|
||||
"weight_decay": 0.0,
|
||||
"adam_epsilon": 1e-08,
|
||||
"warmup_steps": 0,
|
||||
"max_epochs": 1,
|
||||
"train_batch_size": 2,
|
||||
"eval_batch_size": 2,
|
||||
"max_source_length": 12,
|
||||
"max_target_length": 12,
|
||||
"val_max_target_length": 12,
|
||||
"test_max_target_length": 12,
|
||||
"fast_dev_run": False,
|
||||
"no_cache": False,
|
||||
"n_train": -1,
|
||||
"n_val": -1,
|
||||
"n_test": -1,
|
||||
"student_encoder_layers": 1,
|
||||
"freeze_encoder": False,
|
||||
"auto_scale_batch_size": False,
|
||||
"overwrite_output_dir": False,
|
||||
"student": None,
|
||||
}
|
||||
|
||||
|
||||
def _dump_articles(path: Path, articles: list):
|
||||
@@ -98,187 +34,15 @@ def _dump_articles(path: Path, articles: list):
|
||||
Path(path).open("w").writelines(content)
|
||||
|
||||
|
||||
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
T5_TINIER = "sshleifer/t5-tinier-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||
FSMT_TINY = "stas/tiny-wmt19-en-de"
|
||||
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
|
||||
|
||||
def make_test_data_dir(tmp_dir):
|
||||
for split in ["train", "val", "test"]:
|
||||
_dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
|
||||
_dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class TestSummarizationDistiller(TestCasePlus):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
return cls
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_hub_configs(self):
|
||||
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
|
||||
|
||||
model_list = HfApi().model_list()
|
||||
org = "sshleifer"
|
||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
||||
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
|
||||
failures = []
|
||||
for m in model_ids:
|
||||
if m in allowed_to_be_broken:
|
||||
continue
|
||||
try:
|
||||
AutoConfig.from_pretrained(m)
|
||||
except Exception:
|
||||
failures.append(m)
|
||||
assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
|
||||
|
||||
def test_distill_no_teacher(self):
|
||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_checkpointing_with_teacher(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
max_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
)
|
||||
model = self._test_distiller_cli(updates, check_contents=False)
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
self.assertEqual(1, len(ckpts))
|
||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
self.assertEqual(len(transformer_ckpts), 2)
|
||||
examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
|
||||
out_path = tempfile.mktemp() # XXX: not being cleaned up
|
||||
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
|
||||
self.assertTrue(Path(out_path).exists())
|
||||
|
||||
out_path_new = self.get_auto_remove_tmp_dir()
|
||||
convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
|
||||
assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
|
||||
|
||||
def test_loss_fn(self):
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
|
||||
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
|
||||
target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
|
||||
lm_labels = target_ids[:, 1:].clone() # why clone?
|
||||
model_computed_loss = model(
|
||||
input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=lm_labels, use_cache=False
|
||||
).loss
|
||||
|
||||
logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
|
||||
|
||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
smoothed_loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
|
||||
)
|
||||
with self.assertRaises(AssertionError):
|
||||
# TODO: understand why this breaks
|
||||
self.assertEqual(nll_loss, model_computed_loss)
|
||||
|
||||
def test_distill_mbart(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
num_train_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
task="translation",
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
tokenizer_name=MBART_TINY,
|
||||
teacher=MBART_TINY,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
)
|
||||
model = self._test_distiller_cli(updates, check_contents=False)
|
||||
assert model.model.config.model_type == "mbart"
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
self.assertEqual(1, len(ckpts))
|
||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
all_files = list(Path(model.output_dir).glob("best_tfmr/*"))
|
||||
assert len(all_files) > 2
|
||||
self.assertEqual(len(transformer_ckpts), 2)
|
||||
|
||||
def test_distill_t5(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=1,
|
||||
student_decoder_layers=1,
|
||||
alpha_hid=2.0,
|
||||
teacher=T5_TINY,
|
||||
model_name_or_path=T5_TINY,
|
||||
tokenizer_name=T5_TINY,
|
||||
)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_different_base_models(self):
|
||||
updates = dict(
|
||||
teacher=T5_TINY,
|
||||
student=T5_TINIER,
|
||||
model_name_or_path=T5_TINIER,
|
||||
tokenizer_name=T5_TINIER,
|
||||
)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def _test_distiller_cli(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
label_smoothing=0.0,
|
||||
early_stopping_patience=-1,
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
max_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
)
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
||||
model = distill_main(argparse.Namespace(**args_d))
|
||||
if not check_contents:
|
||||
return model
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
ckpt_files = [p for p in contents if p.endswith("ckpt")]
|
||||
assert len(ckpt_files) > 0
|
||||
|
||||
self.assertIn("test_generations.txt", contents)
|
||||
self.assertIn("test_results.txt", contents)
|
||||
|
||||
metrics = load_json(model.metrics_save_path)
|
||||
last_step_stats = metrics["val"][-1]
|
||||
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
||||
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||
self.assertEqual(len(metrics["test"]), 1)
|
||||
return model
|
||||
|
||||
|
||||
class TestTheRest(TestCasePlus):
|
||||
def run_eval_tester(self, model):
|
||||
input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
|
||||
@@ -365,167 +129,3 @@ class TestTheRest(TestCasePlus):
|
||||
assert w not in cs.out
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
@parameterized.expand(
|
||||
[T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
|
||||
)
|
||||
def test_finetune(self, model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
|
||||
args_d["label_smoothing"] = 0.1 if task == "translation" else 0
|
||||
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
output_dir=output_dir,
|
||||
do_predict=True,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
assert "n_train" in args_d
|
||||
args = argparse.Namespace(**args_d)
|
||||
module = main(args)
|
||||
|
||||
input_embeds = module.model.get_input_embeddings()
|
||||
assert not input_embeds.weight.requires_grad
|
||||
if model == T5_TINY:
|
||||
lm_head = module.model.lm_head
|
||||
assert not lm_head.weight.requires_grad
|
||||
assert (lm_head.weight == input_embeds.weight).all().item()
|
||||
elif model == FSMT_TINY:
|
||||
fsmt = module.model.model
|
||||
embed_pos = fsmt.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not fsmt.decoder.embed_tokens.weight.requires_grad
|
||||
# check that embeds are not the same
|
||||
assert fsmt.decoder.embed_tokens != fsmt.encoder.embed_tokens
|
||||
else:
|
||||
bart = module.model.model
|
||||
embed_pos = bart.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not bart.shared.weight.requires_grad
|
||||
# check that embeds are the same
|
||||
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
|
||||
assert bart.decoder.embed_tokens == bart.shared
|
||||
|
||||
example_batch = load_json(module.output_dir / "text_batch.json")
|
||||
assert isinstance(example_batch, dict)
|
||||
assert len(example_batch) >= 4
|
||||
|
||||
def test_finetune_extra_model_args(self):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
|
||||
# test models whose config includes the extra_model_args
|
||||
model = BART_TINY
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d1 = args_d.copy()
|
||||
args_d1.update(
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
args_d1[p] = 0.5
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
for p in extra_model_params:
|
||||
assert getattr(model.config, p) == 0.5, f"failed to override the model config for param {p}"
|
||||
|
||||
# test models whose config doesn't include the extra_model_args
|
||||
model = T5_TINY
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d2 = args_d.copy()
|
||||
args_d2.update(
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
unsupported_param = "encoder_layerdrop"
|
||||
args_d2[unsupported_param] = 0.5
|
||||
args = argparse.Namespace(**args_d2)
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
model = main(args)
|
||||
assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
|
||||
|
||||
def test_finetune_lr_schedulers(self):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
task = "summarization"
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
|
||||
model = BART_TINY
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
do_predict=False,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
|
||||
# emulate finetune.py
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
|
||||
args = {"--help": True}
|
||||
|
||||
# --help test
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
with CaptureStdout() as cs:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "--help is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
expected = lightning_base.arg_to_scheduler_metavar
|
||||
assert expected in cs.out, "--help is expected to list the supported schedulers"
|
||||
|
||||
# --lr_scheduler=non_existing_scheduler test
|
||||
unsupported_param = "non_existing_scheduler"
|
||||
args = {f"--lr_scheduler={unsupported_param}"}
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
with CaptureStderr() as cs:
|
||||
args = parser.parse_args(args)
|
||||
assert False, "invalid argument is expected to sys.exit"
|
||||
assert excinfo.type == SystemExit
|
||||
expected = f"invalid choice: '{unsupported_param}'"
|
||||
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"
|
||||
|
||||
# --lr_scheduler=existing_scheduler test
|
||||
supported_param = "cosine"
|
||||
args_d1 = args_d.copy()
|
||||
args_d1["lr_scheduler"] = supported_param
|
||||
args = argparse.Namespace(**args_d1)
|
||||
model = main(args)
|
||||
assert (
|
||||
getattr(model.hparams, "lr_scheduler") == supported_param
|
||||
), f"lr_scheduler={supported_param} shouldn't fail"
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# as due to their complexity multi-gpu tests could impact other tests, and to aid debug we have those in a separate module.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
)
|
||||
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, require_torch_gpu, slow
|
||||
|
||||
from .test_seq2seq_examples import CHEAP_ARGS, make_test_data_dir
|
||||
from .utils import load_json
|
||||
|
||||
|
||||
@@ -21,73 +27,6 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
def setUpClass(cls):
|
||||
return cls
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu(self):
|
||||
|
||||
updates = dict(
|
||||
no_teacher=True,
|
||||
freeze_encoder=True,
|
||||
gpus=2,
|
||||
overwrite_output_dir=True,
|
||||
sortish_sampler=True,
|
||||
)
|
||||
self._test_distiller_cli_fork(updates, check_contents=False)
|
||||
|
||||
def _test_distiller_cli_fork(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
label_smoothing=0.0,
|
||||
early_stopping_patience=-1,
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
max_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
)
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
||||
|
||||
def convert(k, v):
|
||||
if k in ["tgt_suffix", "server_ip", "server_port", "out", "n_tpu_cores"]:
|
||||
return ""
|
||||
if v is False or v is None:
|
||||
return ""
|
||||
if v is True: # or len(str(v))==0:
|
||||
return f"--{k}"
|
||||
return f"--{k}={v}"
|
||||
|
||||
cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
|
||||
cmd = [sys.executable, f"{self.test_file_dir}/distillation.py"] + cli_args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
contents = os.listdir(output_dir)
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
ckpt_files = [p for p in contents if p.endswith("ckpt")]
|
||||
assert len(ckpt_files) > 0
|
||||
|
||||
self.assertIn("test_generations.txt", contents)
|
||||
self.assertIn("test_results.txt", contents)
|
||||
|
||||
# get the following from the module, (we don't have access to `model` here)
|
||||
metrics_save_path = os.path.join(output_dir, "metrics.json")
|
||||
val_metric = "rouge2"
|
||||
|
||||
metrics = load_json(metrics_save_path)
|
||||
# {'test': [{'test_avg_loss': 10.63731575012207, 'test_avg_rouge1': 0.0, 'test_avg_rouge2': 0.0, 'test_avg_rougeL': 0.0, 'test_avg_gen_time': 0.1822289228439331, 'test_avg_gen_len': 142.0, 'step_count': 1}]}
|
||||
print(metrics)
|
||||
last_step_stats = metrics["val"][-1]
|
||||
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||
self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
|
||||
self.assertEqual(len(metrics["test"]), 1)
|
||||
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
|
||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_distributed_eval(self):
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
export WANDB_PROJECT=distil-marian
|
||||
export BS=64
|
||||
export GAS=1
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
export WANDB_PROJECT=distil-marian
|
||||
export BS=64
|
||||
export m=sshleifer/student_marian_en_ro_6_3
|
||||
49
examples/seq2seq/train_distilbart_cnn.sh
Executable file → Normal file
49
examples/seq2seq/train_distilbart_cnn.sh
Executable file → Normal file
@@ -1,24 +1,39 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
export WANDB_PROJECT=distilbart-trainer
|
||||
export BS=32
|
||||
export GAS=1
|
||||
export m=sshleifer/student_cnn_12_6
|
||||
export tok=facebook/bart-large
|
||||
export MAX_TGT_LEN=142
|
||||
|
||||
python finetune.py \
|
||||
python finetune_trainer.py \
|
||||
--model_name_or_path $m --tokenizer_name $tok \
|
||||
--data_dir cnn_dm \
|
||||
--output_dir distilbart-cnn-12-6 --overwrite_output_dir \
|
||||
--learning_rate=3e-5 \
|
||||
--warmup_steps 500 --sortish_sampler \
|
||||
--fp16 \
|
||||
--gpus 1 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--val_check_interval 0.25 \
|
||||
--n_val 500 \
|
||||
--num_train_epochs 2 \
|
||||
--freeze_encoder --freeze_embeds --data_dir cnn_dm \
|
||||
--max_target_length 142 --val_max_target_length=142 \
|
||||
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
|
||||
--model_name_or_path sshleifer/student_cnn_12_6 \
|
||||
--tokenizer_name facebook/bart-large \
|
||||
--warmup_steps 500 \
|
||||
--output_dir distilbart-cnn-12-6 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--num_train_epochs=2 \
|
||||
--save_steps 3000 --eval_steps 3000 \
|
||||
--logging_first_step \
|
||||
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate --sortish_sampler \
|
||||
"$@"
|
||||
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
python distillation.py \
|
||||
--teacher facebook/bart-large-xsum --data_dir xsum \
|
||||
--tokenizer_name facebook/bart-large-xsum \
|
||||
--student_decoder_layers 6 --student_encoder_layers 12 \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--fp16 --fp16_opt_level=O1 \
|
||||
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
|
||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||
--model_name_or_path IGNORED \
|
||||
--alpha_hid=3. \
|
||||
--train_batch_size=16 --eval_batch_size=16 --gradient_accumulation_steps=2 \
|
||||
--sortish_sampler \
|
||||
--num_train_epochs=6 \
|
||||
--warmup_steps 500 \
|
||||
--output_dir distilbart_xsum_12_6 \
|
||||
"$@"
|
||||
48
examples/seq2seq/train_mbart_cc25_enro.sh
Executable file → Normal file
48
examples/seq2seq/train_mbart_cc25_enro.sh
Executable file → Normal file
@@ -1,18 +1,36 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
python finetune.py \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--do_train \
|
||||
--val_check_interval=0.25 \
|
||||
--adam_eps 1e-06 \
|
||||
--num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
|
||||
--data_dir $ENRO_DIR \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--train_batch_size=$BS --eval_batch_size=$BS \
|
||||
--task translation \
|
||||
--warmup_steps 500 \
|
||||
--freeze_embeds \
|
||||
python finetune_trainer.py \
|
||||
--model_name_or_path=facebook/mbart-large-cc25 \
|
||||
--data_dir $ENRO_DIR \
|
||||
--output_dir mbart_cc25_enro --overwrite_output_dir \
|
||||
--learning_rate=3e-5 \
|
||||
--warmup_steps 500 \
|
||||
--fp16 \
|
||||
--label_smoothing 0.1 \
|
||||
--adam_eps 1e-06 \
|
||||
--src_lang en_XX --tgt_lang ro_RO \
|
||||
--freeze_embeds \
|
||||
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
|
||||
--max_source_length 128 --max_target_length 128 \
|
||||
--val_max_target_length 128 --test_max_target_length 128 \
|
||||
--sortish_sampler \
|
||||
--num_train_epochs 6 \
|
||||
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate --logging_first_step \
|
||||
--task translation \
|
||||
"$@"
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import linecache
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
A simple launcher script for TPU training
|
||||
|
||||
|
||||
Reference in New Issue
Block a user