examples/seq2seq supports translation (#5202)
169
examples/seq2seq/README.md
Normal file
@@ -0,0 +1,169 @@
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
Summarization support is more mature than translation support.
Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
For `bertabs` instructions, see `bertabs/README.md`.

### Data

CNN/DailyMail data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz

export CNN_DIR=${PWD}/cnn_dm
```

This should create a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file format. Each article to be summarized is on its own line.

XSUM Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```

WMT16 English-Romanian Translation Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export ENRO_DIR=${PWD}/wmt_en_ro
```

If you are using your own data, it must be formatted as one directory with 6 files: `train.source`, `train.target`, `val.source`, `val.target`, `test.source`, `test.target`.
The `.source` files are the input; the `.target` files are the desired output.

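As a rough illustration of the layout described above, the sketch below writes a toy dataset into the six expected files and checks that each `.source` file has as many lines as its `.target` partner. The helper names `write_split` and `check_dataset_dir` and the toy sentences are made up for this example; they are not part of the repository.

```python
import tempfile
from pathlib import Path

SPLITS = ("train", "val", "test")


def write_split(data_dir: Path, split: str, pairs) -> None:
    """Write (source, target) pairs for one split, one example per line."""
    # Newlines inside an example would break the one-example-per-line format.
    sources = [src.replace("\n", " ") for src, _ in pairs]
    targets = [tgt.replace("\n", " ") for _, tgt in pairs]
    (data_dir / f"{split}.source").write_text("\n".join(sources) + "\n", encoding="utf-8")
    (data_dir / f"{split}.target").write_text("\n".join(targets) + "\n", encoding="utf-8")


def check_dataset_dir(data_dir: Path) -> dict:
    """Return the number of examples per split; raise if files are missing or misaligned."""
    counts = {}
    for split in SPLITS:
        src_lines = (data_dir / f"{split}.source").read_text(encoding="utf-8").splitlines()
        tgt_lines = (data_dir / f"{split}.target").read_text(encoding="utf-8").splitlines()
        if len(src_lines) != len(tgt_lines):
            raise ValueError(f"{split}: {len(src_lines)} sources vs {len(tgt_lines)} targets")
        counts[split] = len(src_lines)
    return counts


# Invented toy data, only to exercise the helpers.
toy_pairs = [("A long article about cats.", "Cats."), ("A long article about dogs.", "Dogs.")]
demo_dir = Path(tempfile.mkdtemp())
for split in SPLITS:
    write_split(demo_dir, split, toy_pairs)
print(check_dataset_dir(demo_dir))
```

Running a check like this before training can catch misaligned files early, since source line `i` is paired with target line `i`.
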
### Evaluation

To create summaries for each article in the dataset, run:
```bash
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
```
A batch size of 4 fits in 16GB GPU memory; adjust it with `--bs` to fit your system (the parser default in this commit is 8).
ROUGE (or BLEU, with `--metric bleu`) is only computed when `--reference_path` is also supplied.

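To see how the batch size affects the number of generation calls, here is a small sketch of the slicing done by the `chunks` helper in `run_eval.py` (`yield lst[i : i + n]`). The `num_batches` helper and the example list are illustrative additions, not repository code.

```python
import math


def chunks(lst, n):
    """Yield successive n-sized chunks from lst (the last one may be shorter)."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


def num_batches(n_examples: int, batch_size: int) -> int:
    """Hypothetical helper: how many batches a given batch size produces."""
    return math.ceil(n_examples / batch_size)


articles = [f"article {i}" for i in range(10)]
batches = list(chunks(articles, 4))
print([len(b) for b in batches])
print(num_batches(len(articles), 4))
```

Halving the batch size roughly doubles the number of `generate` calls, which is the trade-off to keep in mind when lowering it to fit a smaller GPU.
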
### Summarization Finetuning
Run/modify `finetune.sh`

The following command should work on a 16GB GPU:
```bash
./finetune.sh \
    --data_dir $XSUM_DIR \
    --train_batch_size=1 \
    --eval_batch_size=1 \
    --output_dir=xsum_results \
    --num_train_epochs 1 \
    --model_name_or_path facebook/bart-large
```

*Note*: The following tips mostly apply to summarization finetuning.

Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- Try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default) works best.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
- This warning can be safely ignored:
    > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).

#### Finetuning Outputs
As you train, `output_dir` will be filled with files that look roughly like this (comments are mine).
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:

```bash
output_dir
├── best_tfmr  # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
│   ├── config.json
│   ├── merges.txt
│   ├── pytorch_model.bin
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── git_log.json  # repo, branch, and commit hash
├── val_avg_rouge2=0.1984-step_count=11.ckpt  # this is a pytorch lightning checkpoint associated with the best val score.
├── metrics.json  # new validation metrics will continually be appended to this
├── student  # this is a huggingface checkpoint generated by BartSummarizationDistiller. It is the student before it gets finetuned.
│   ├── config.json
│   └── pytorch_model.bin
├── test_generations.txt
# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
├── test_results.txt  # a convenience file with the test set metrics. This data is also in metrics.json['test']
└── hparams.pkl  # the command line args passed after some light preprocessing. Should be saved fairly quickly.
```
After training, you can recover the best checkpoint by running
```python
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```

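Since validation metrics are appended to `metrics.json` as training runs, a few lines of Python are enough to find the best validation step. The sketch below assumes the layout suggested by `save_metrics` in this commit (a dict keyed by `"val"`/`"test"`, each holding a list of dicts with keys such as `val_avg_rouge2` and `step_count`). The `best_val_entry` helper and the numbers in the toy file are invented for illustration.

```python
import json
import tempfile
from pathlib import Path


def best_val_entry(metrics_path: Path, metric: str = "rouge2") -> dict:
    """Return the validation entry with the highest `val_avg_<metric>` value."""
    metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
    key = f"val_avg_{metric}"
    return max(metrics["val"], key=lambda entry: entry[key])


# Toy file standing in for `output_dir/metrics.json`; the values are made up.
toy_metrics = {
    "val": [
        {"val_avg_rouge2": 0.1512, "val_avg_loss": 3.1, "step_count": 1},
        {"val_avg_rouge2": 0.1984, "val_avg_loss": 2.7, "step_count": 11},
        {"val_avg_rouge2": 0.1901, "val_avg_loss": 2.8, "step_count": 12},
    ]
}
toy_path = Path(tempfile.mkdtemp()) / "metrics.json"
toy_path.write_text(json.dumps(toy_metrics), encoding="utf-8")

best = best_val_entry(toy_path)
print(best["step_count"], best["val_avg_rouge2"])
```

For a translation run, passing `metric="bleu"` would look up `val_avg_bleu` instead, assuming the same key pattern.
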
### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.

Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
```bash
./finetune.sh \
    --data_dir $XSUM_DIR \
    --output_dir xsum_frozen_embs \
    --model_name_or_path facebook/bart-large \
    --logger wandb_shared \
    --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
    --num_train_epochs 6 \
    --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
```

Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)

### Distilbart

#### No Teacher Distillation
To run the simpler distilbart-cnn style distillation all you need is data, a GPU, and a properly initialized student.
You don't even need `distillation.py`.

Some [un-finetuned students](https://huggingface.co/models?search=sshleifer%2Fstudent) are available for replication purposes.
They are initialized by copying layers from the associated `bart-large-{cnn|xsum}` teacher using `--init_strategy alternate`. (You can read about that in `initialization_utils.py`)
The command that produced `sshleifer/distilbart-cnn-12-6` is
```bash
./train_distilbart_cnn.sh
```
Runtime: 6 hours on an NVIDIA RTX 24GB GPU.

*Note*: You can get the same simple distillation logic by using `./run_distiller.sh --no_teacher` followed by the same arguments as in `train_distilbart_cnn.sh`.
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
because you will have the same hyperparameters logged in every run.

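To make the layer-copying idea concrete, the sketch below reproduces the lookup used by `get_layers_to_copy` in `distillation.py` as changed in this commit: for a 12-layer teacher, a fixed table picks roughly evenly spaced teacher layers; for any other depth, the first `n` layers are taken. The function name `pick_teacher_layers` is hypothetical, and whether `--init_strategy alternate` uses exactly this table is an assumption.

```python
from typing import List

# Teacher layers to copy for a 12-layer teacher, keyed by the number of student layers.
# Values follow the updated table in `get_layers_to_copy` from this commit.
LAYERS_TO_COPY_12 = {
    1: [0],
    2: [0, 6],
    3: [0, 6, 11],
    4: [0, 4, 8, 11],
    6: [0, 2, 4, 7, 9, 11],
    9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
    12: list(range(12)),
}


def pick_teacher_layers(n_student: int, n_teacher: int) -> List[int]:
    """Return the indices of the teacher layers a student with `n_student` layers copies."""
    if n_teacher == 12:
        return LAYERS_TO_COPY_12[n_student]
    # Fallback for other teacher depths: take the first n_student layers.
    return list(range(n_teacher))[:n_student]


print(pick_teacher_layers(6, 12))
print(pick_teacher_layers(3, 6))
```

A name like `distilbart-cnn-12-6` then reads as 12 encoder layers and 6 decoder layers, with the 6 decoder layers drawn from the table row for `6`.
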
#### With a teacher
*Note*: only BART variants are supported.

In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
This is how `sshleifer/distilbart-xsum*` checkpoints were produced.

The command that produced `sshleifer/distilbart-xsum-12-6` is:

```bash
./train_distilbart_xsum.sh
```

Runtime: 13 hours on a V-100 16GB GPU.

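The idea of pushing the student towards the teacher's logits and hidden states can be sketched in plain Python. This is a generic knowledge-distillation loss (KL divergence on temperature-softened logits plus mean squared error on hidden states), not the exact loss of `BartSummarizationDistiller`. `alpha_hid` mirrors a flag in this commit's argument parser, while `alpha_logits`, `temperature`, and all function names are assumptions made for the example.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over one vector of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) for two discrete distributions with q > 0 everywhere."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mse(a: List[float], b: List[float]) -> float:
    """Mean squared error between two equally sized vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distillation_loss(
    student_logits: List[float],
    teacher_logits: List[float],
    student_hidden: List[float],
    teacher_hidden: List[float],
    alpha_logits: float = 1.0,  # hypothetical weight for the logit term
    alpha_hid: float = 1.0,  # weight for the hidden-state term
    temperature: float = 2.0,  # hypothetical softening temperature
) -> float:
    """Weighted sum of a logit-matching term and a hidden-state-matching term."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    logit_term = kl_divergence(teacher_probs, student_probs)
    hidden_term = mse(student_hidden, teacher_hidden)
    return alpha_logits * logit_term + alpha_hid * hidden_term


identical = distillation_loss([2.0, 0.5, -1.0], [2.0, 0.5, -1.0], [0.1, 0.2], [0.1, 0.2])
different = distillation_loss([0.0, 0.0, 3.0], [2.0, 0.5, -1.0], [0.9, -0.4], [0.1, 0.2])
print(identical, different)
```

The loss is zero when the student reproduces the teacher exactly and grows as their outputs diverge; in a real setup it would be combined with the usual cross-entropy loss on the reference summaries.
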
### Contributing
- Follow the standard contributing guidelines and code of conduct.
- Add tests to `test_seq2seq_examples.py`.
- To run only the seq2seq tests, you must be in the root of the repository and run:
```bash
pytest examples/seq2seq/
```
@@ -12,7 +12,7 @@ The model is loaded with the pre-trained weights for the abstractive summarizati
|
||||
git clone https://github.com/huggingface/transformers && cd transformers
|
||||
pip install .
|
||||
pip install nltk py-rouge
|
||||
cd examples/summarization
|
||||
cd examples/seq2seq/bertabs
|
||||
```
|
||||
|
||||
## Reproduce the authors' ROUGE score
|
||||
@@ -32,9 +32,12 @@ class Seq2SeqLoggingCallback(pl.Callback):
|
||||
results_file = od / "test_results.txt"
|
||||
generations_file = od / "test_generations.txt"
|
||||
else:
|
||||
results_file = od / f"{type_path}_results_{trainer.global_step:05d}.txt"
|
||||
generations_file = od / f"{type_path}_generations_{trainer.global_step:05d}.txt"
|
||||
|
||||
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
|
||||
# If people want this it will be easy enough to add back.
|
||||
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
|
||||
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
|
||||
results_file.parent.mkdir(exist_ok=True)
|
||||
generations_file.parent.mkdir(exist_ok=True)
|
||||
with open(results_file, "a+") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key in ["log", "progress_bar", "preds"]:
|
||||
@@ -63,20 +66,25 @@ class Seq2SeqLoggingCallback(pl.Callback):
|
||||
# mp stands for million parameters
|
||||
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
||||
|
||||
@rank_zero_only
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
return self._write_logs(trainer, pl_module, "val")
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
return self._write_logs(trainer, pl_module, "test")
|
||||
|
||||
|
||||
def get_rouge2_checkpoint_callback(output_dir):
|
||||
def get_checkpoint_callback(output_dir, metric):
|
||||
"""Saves the best model by validation ROUGE2 score."""
|
||||
if metric == "rouge2":
|
||||
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
||||
elif metric == "bleu":
|
||||
exp = "{val_avg_bleu:.4f}-{step_count}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=os.path.join(output_dir, "{val_avg_rouge2:.4f}-{step_count}"),
|
||||
monitor="val_rouge",
|
||||
filepath=os.path.join(output_dir, exp),
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
save_top_k=1,
|
||||
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
@@ -39,13 +39,12 @@ except ImportError:
|
||||
)
|
||||
|
||||
|
||||
class SummarizationDistiller(SummarizationModule):
|
||||
class BartSummarizationDistiller(SummarizationModule):
|
||||
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
|
||||
|
||||
def __init__(self, hparams):
|
||||
assert Path(hparams.data_dir).exists()
|
||||
|
||||
d_layers_to_copy, student, student_cfg, teacher = self.pre_init(hparams)
|
||||
student, student_cfg, teacher = self.pre_init(hparams)
|
||||
|
||||
super().__init__(hparams, model=student, config=student_cfg)
|
||||
self.teacher = teacher
|
||||
@@ -73,12 +72,15 @@ class SummarizationDistiller(SummarizationModule):
|
||||
del self.teacher.model.encoder
|
||||
|
||||
def pre_init(self, hparams):
|
||||
# Dump empty student model at a path, then call from_pretrained on it
|
||||
self.output_dir = Path(hparams.output_dir)
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
|
||||
student_updates = {
|
||||
"decoder_layers": hparams.student_decoder_layers,
|
||||
"encoder_layers": hparams.student_encoder_layers,
|
||||
}
|
||||
if hparams.length_penalty != -1:
|
||||
student_updates["length_penalty"] = hparams.length_penalty
|
||||
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
|
||||
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
|
||||
hparams.d_layer_to_copy = d_layers_to_copy
|
||||
@@ -89,9 +91,13 @@ class SummarizationDistiller(SummarizationModule):
|
||||
student_cfg = BartConfig(**kw)
|
||||
student = BartForConditionalGeneration(student_cfg)
|
||||
student, _ = init_student(student, teacher)
|
||||
save_dir = self.output_dir.joinpath("student")
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
|
||||
Path(hparams.output_dir).mkdir(exist_ok=True)
|
||||
return d_layers_to_copy, student, student_cfg, teacher
|
||||
student.save_pretrained(save_dir)
|
||||
hparams.model_name_or_path = str(save_dir)
|
||||
return student, student_cfg, teacher
|
||||
|
||||
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
|
||||
if teacher.config.model_type == "t5":
|
||||
@@ -154,7 +160,6 @@ class SummarizationDistiller(SummarizationModule):
|
||||
|
||||
def configure_optimizers(self):
|
||||
"Prepare optimizer and schedule (linear warmup and decay)"
|
||||
|
||||
model = self.model
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
@@ -180,18 +185,11 @@ class SummarizationDistiller(SummarizationModule):
|
||||
# parser.add_argument("--alpha_cos", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
|
||||
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
|
||||
parser.add_argument(
|
||||
"--student_decoder_layers", default=12, type=int, required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--student_encoder_layers", default=12, type=int, required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_teacher", action="store_true", default=False,
|
||||
)
|
||||
parser.add_argument( # TODO: remove
|
||||
"--enc_only", action="store_true", default=False,
|
||||
)
|
||||
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
|
||||
parser.add_argument("--no_teacher", action="store_true", default=False)
|
||||
parser.add_argument("--length_penalty", type=float, default=-1)
|
||||
|
||||
return parser
|
||||
|
||||
def _step(self, batch):
|
||||
@@ -269,12 +267,14 @@ class SummarizationDistiller(SummarizationModule):
|
||||
return sum(hidden_losses)
|
||||
|
||||
|
||||
class T5SummarizationDistiller(SummarizationDistiller):
|
||||
class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||
def pre_init(self, hparams):
|
||||
raise NotImplementedError("T5 Distillation does not work yet")
|
||||
self.output_dir = Path(hparams.output_dir)
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
|
||||
n_layer = hparams.student_decoder_layers
|
||||
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this
|
||||
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this constraint so that we can do 12-6.
|
||||
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
|
||||
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
|
||||
student_updates = {"num_layers": n_layer}
|
||||
@@ -291,8 +291,13 @@ class T5SummarizationDistiller(SummarizationDistiller):
|
||||
Path(hparams.output_dir).mkdir(exist_ok=True)
|
||||
task_specific_params = student.config.task_specific_params
|
||||
if task_specific_params is not None:
|
||||
student.config.update(task_specific_params.get("summarization", {}))
|
||||
return d_layers_to_copy, student, student_cfg, teacher
|
||||
student.config.update(task_specific_params.get("summarization", {})) # TODO: dont hardcode
|
||||
save_dir = self.output_dir.joinpath("student")
|
||||
save_dir.mkdir(exist_ok=True)
|
||||
|
||||
student.save_pretrained(save_dir)
|
||||
hparams.model_name_or_path = str(save_dir)
|
||||
return student, student_cfg, teacher
|
||||
|
||||
def freeze_embeds(self):
|
||||
freeze_params(self.model.shared)
|
||||
@@ -386,7 +391,7 @@ def create_module(args):
|
||||
elif args.enc_only:
|
||||
raise ValueError("Deleted that")
|
||||
else:
|
||||
module_cls = SummarizationDistiller
|
||||
module_cls = BartSummarizationDistiller
|
||||
args.setup_cls: str = module_cls.__name__
|
||||
model = module_cls(args)
|
||||
return model
|
||||
@@ -418,18 +423,18 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
|
||||
def get_layers_to_copy(n_to_get, tot):
|
||||
all_layers = list(range(tot))
|
||||
if tot == 12: # Alternating for special cases
|
||||
layers_to_copy = { # maps # layers in student -> which teacher layers to copy
|
||||
6: [0, 2, 4, 7, 9, 11],
|
||||
1: [11],
|
||||
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
|
||||
1: [0],
|
||||
2: [0, 6],
|
||||
3: [0, 6, 11],
|
||||
2: [0, 11],
|
||||
4: [0, 4, 8, 11],
|
||||
6: [0, 2, 4, 7, 9, 11],
|
||||
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
|
||||
12: all_layers,
|
||||
}
|
||||
return layers_to_copy[n_to_get]
|
||||
else:
|
||||
return all_layers[:n_to_get]
|
||||
return all_layers[:n_to_get] # TODO: better version on theseus-bart branch
|
||||
|
||||
|
||||
def distill_main(args):
|
||||
@@ -443,7 +448,7 @@ def distill_main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
|
||||
args = parser.parse_args()
|
||||
|
||||
distill_main(args)
|
||||
@@ -3,6 +3,7 @@ import glob
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
@@ -23,12 +24,14 @@ try:
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
)
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
except ImportError:
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
@@ -37,12 +40,14 @@ except ImportError:
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
)
|
||||
from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,15 +55,18 @@ logger = logging.getLogger(__name__)
|
||||
class SummarizationModule(BaseTransformer):
|
||||
mode = "summarization"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ROUGE_KEYS
|
||||
val_metric = "rouge2"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
|
||||
use_task_specific_params(self.model, "summarization")
|
||||
save_git_info(self.hparams.output_dir)
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
|
||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||
pickle_save(self.hparams, self.hparams_save_path)
|
||||
self.step_count = 0
|
||||
self.metrics = {"train": [], "val": [], "test": []}
|
||||
self.metrics = defaultdict(list)
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
@@ -89,12 +97,12 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
if self.model.config.model_type == "bart":
|
||||
try:
|
||||
freeze_params(self.model.model.shared)
|
||||
for d in [self.model.model.encoder, self.model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
else:
|
||||
except AttributeError:
|
||||
freeze_params(self.model.shared)
|
||||
for d in [self.model.encoder, self.model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
@@ -130,19 +138,22 @@ class SummarizationModule(BaseTransformer):
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
|
||||
rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
|
||||
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
|
||||
rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
|
||||
rouges.update({k: v.item() for k, v in losses.items()})
|
||||
losses.update(rouges)
|
||||
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
metrics["step_count"] = self.step_count
|
||||
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
||||
preds = flatten_list([x["preds"] for x in outputs])
|
||||
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}
|
||||
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
|
||||
|
||||
def save_metrics(self, metrics, prefix) -> None:
|
||||
self.metrics[prefix].append(metrics)
|
||||
pickle_save(self.metrics, self.metrics_save_path)
|
||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||
self.metrics[type_path].append(latest_metrics)
|
||||
save_json(self.metrics, self.metrics_save_path)
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_rouge(preds, target)
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
@@ -154,7 +165,7 @@ class SummarizationModule(BaseTransformer):
|
||||
target = self.ids_to_clean_text(y)
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
rouge: Dict = calculate_rouge(preds, target)
|
||||
rouge: Dict = self.calc_generative_metrics(preds, target)
|
||||
summ_len = np.mean(lmap(len, generated_ids))
|
||||
base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
|
||||
return base_metrics
|
||||
@@ -259,15 +270,33 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="Task to run: summarization or translation."
)
|
||||
return parser
|
||||
|
||||
|
||||
class TranslationModule(SummarizationModule):
|
||||
mode = "translation"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["bleu"]
|
||||
val_metric = "bleu"
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu_score(preds, target)
|
||||
|
||||
|
||||
def main(args, model=None) -> SummarizationModule:
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
if model is None:
|
||||
model: BaseTransformer = SummarizationModule(args)
|
||||
if args.task == "summarization":
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
else:
|
||||
model: SummarizationModule = TranslationModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger == "default"
|
||||
or args.fast_dev_run
|
||||
@@ -278,17 +307,17 @@ def main(args, model=None) -> SummarizationModule:
|
||||
elif args.logger == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=dataset)
|
||||
|
||||
elif args.logger == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
# TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
|
||||
logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
logger=logger,
|
||||
# TODO: early stopping callback seems messed up
|
||||
)
|
||||
@@ -1,13 +1,8 @@
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
|
||||
# --model_name_or_path=t5-base for t5
|
||||
|
||||
# the proper usage is documented in the README
|
||||
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
|
||||
python finetune.py \
|
||||
--model_name_or_path=facebook/bart-large \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--gpus 1 \
|
||||
@@ -16,5 +11,4 @@ python finetune.py \
|
||||
--n_val 1000 \
|
||||
--val_check_interval 0.1 \
|
||||
--sortish_sampler \
|
||||
--max_target_length=56 \
|
||||
$@
|
||||
0
examples/summarization/finetune_bart_tiny.sh → examples/seq2seq/finetune_bart_tiny.sh
Normal file → Executable file
0
examples/summarization/finetune_t5.sh → examples/seq2seq/finetune_t5.sh
Normal file → Executable file
@@ -1,5 +1,3 @@
|
||||
#CNN_DIR = /home/shleifer/transformers_fork/examples/summarization/bart/cnn_dm
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
try:
|
||||
from .finetune import calculate_rouge, use_task_specific_params
|
||||
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
|
||||
except ImportError:
|
||||
from finetune import calculate_rouge, use_task_specific_params
|
||||
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@@ -22,8 +22,14 @@ def chunks(lst, n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
||||
def generate_summaries(
|
||||
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
|
||||
def generate_summaries_or_translations(
|
||||
examples: list,
|
||||
out_file: str,
|
||||
model_name: str,
|
||||
batch_size: int = 8,
|
||||
device: str = DEFAULT_DEVICE,
|
||||
fp16=False,
|
||||
**gen_kwargs,
|
||||
) -> None:
|
||||
fout = Path(out_file).open("w", encoding="utf-8")
|
||||
model_name = str(model_name)
|
||||
@@ -39,11 +45,10 @@ def generate_summaries(
|
||||
for batch in tqdm(list(chunks(examples, batch_size))):
|
||||
if "t5" in model_name:
|
||||
batch = [model.config.prefix + text for text in batch]
|
||||
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
|
||||
device
|
||||
)
|
||||
summaries = model.generate(**dct)
|
||||
|
||||
batch = tokenizer.batch_encode_plus(
|
||||
batch, max_length=1024, return_tensors="pt", truncation=True, pad_to_max_length=True
|
||||
).to(device)
|
||||
summaries = model.generate(**batch, **gen_kwargs)
|
||||
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
for hypothesis in dec:
|
||||
fout.write(hypothesis + "\n")
|
||||
@@ -57,22 +62,26 @@ def run_generate():
|
||||
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
|
||||
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
|
||||
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
|
||||
parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge")
|
||||
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
|
||||
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
args = parser.parse_args()
|
||||
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
|
||||
|
||||
generate_summaries(
|
||||
generate_summaries_or_translations(
|
||||
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
|
||||
)
|
||||
if args.score_path is not None:
|
||||
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
|
||||
|
||||
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
|
||||
scores = {}
|
||||
if args.reference_path is not None:
|
||||
score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric]
|
||||
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
|
||||
|
||||
rouge: dict = calculate_rouge(output_lns, reference_lns)
|
||||
|
||||
json.dump(rouge, open("score_path", "w+"))
|
||||
scores: dict = score_fn(output_lns, reference_lns)
|
||||
if args.score_path is not None:
|
||||
json.dump(scores, open(args.score_path, "w+"))
|
||||
return scores
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
252
examples/seq2seq/test_seq2seq_examples.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import main
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .utils import SummarizationDataset, lmap, load_json
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
CHEAP_ARGS = {
|
||||
"logger": "default",
|
||||
"length_penalty": 0.5,
|
||||
"cache_dir": "",
|
||||
"task": "summarization",
|
||||
"num_workers": 2,
|
||||
"alpha_hid": 0,
|
||||
"freeze_embeds": True,
|
||||
"enc_only": False,
|
||||
"tgt_suffix": "",
|
||||
"resume_from_checkpoint": None,
|
||||
"sortish_sampler": True,
|
||||
"student_decoder_layers": 1,
|
||||
"val_check_interval": 1.0,
|
||||
"output_dir": "",
|
||||
"fp16": CUDA_AVAILABLE,
|
||||
"no_teacher": False,
|
||||
"fp16_opt_level": "O1",
|
||||
"gpus": 1 if CUDA_AVAILABLE else 0,
|
||||
"n_tpu_cores": 0,
|
||||
"max_grad_norm": 1.0,
|
||||
"do_train": True,
|
||||
"do_predict": True,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"server_ip": "",
|
||||
"server_port": "",
|
||||
"seed": 42,
|
||||
"model_name_or_path": "sshleifer/bart-tiny-random",
|
||||
"config_name": "",
|
||||
"tokenizer_name": "facebook/bart-large",
|
||||
"do_lower_case": False,
|
||||
"learning_rate": 0.3,
|
||||
"weight_decay": 0.0,
|
||||
"adam_epsilon": 1e-08,
|
||||
"warmup_steps": 0,
|
||||
"num_train_epochs": 1,
|
||||
"train_batch_size": 2,
|
||||
"eval_batch_size": 2,
|
||||
"max_source_length": 12,
|
||||
"max_target_length": 12,
|
||||
"val_max_target_length": 12,
|
||||
"test_max_target_length": 12,
|
||||
"fast_dev_run": False,
|
||||
"no_cache": False,
|
||||
"n_train": -1,
|
||||
"n_val": -1,
|
||||
"n_test": -1,
|
||||
"student_encoder_layers": 1,
|
||||
"alpha_loss_encoder": 0.0,
|
||||
"freeze_encoder": False,
|
||||
"auto_scale_batch_size": False,
|
||||
}
|
||||
|
||||
|
||||
def _dump_articles(path: Path, articles: list):
|
||||
with path.open("w") as f:
|
||||
f.write("\n".join(articles))
|
||||
|
||||
|
||||
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
|
||||
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
BART_TINY = "sshleifer/bart-tiny-random"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
|
||||
|
||||
def make_test_data_dir(**kwargs):
|
||||
tmp_dir = Path(tempfile.mkdtemp(**kwargs))
|
||||
for split in ["train", "val", "test"]:
|
||||
_dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
|
||||
_dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class TestSummarizationDistiller(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
return cls
|
||||
|
||||
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
|
||||
def test_multigpu(self):
|
||||
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_no_teacher(self):
|
||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_checkpointing_with_teacher(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
num_train_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
||||
)
|
||||
model = self._test_distiller_cli(updates, check_contents=False)
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
self.assertEqual(1, len(ckpts))
|
||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
self.assertEqual(len(transformer_ckpts), 2)
|
||||
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
|
||||
out_path = tempfile.mktemp()
|
||||
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
|
||||
self.assertTrue(Path(out_path).exists())
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
@unittest.skip("T5 distillation is broken at the moment")
|
||||
def test_distill_t5(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=1,
|
||||
student_decoder_layers=1,
|
||||
alpha_hid=2.0,
|
||||
teacher=T5_TINY,
|
||||
model_name_or_path=T5_TINY,
|
||||
tokenizer_name=T5_TINY,
|
||||
)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def _test_distiller_cli(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
num_train_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
alpha_encoder_loss=0.4,
|
||||
)
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir()
|
||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||
|
||||
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
||||
model = distill_main(argparse.Namespace(**args_d))
|
||||
if not check_contents:
|
||||
return model
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
self.assertIn(ckpt_name, contents)
|
||||
|
||||
self.assertIn("test_generations.txt", contents)
|
||||
self.assertIn("test_results.txt", contents)
|
||||
|
||||
metrics = load_json(model.metrics_save_path)
|
||||
last_step_stats = metrics["val"][-1]
|
||||
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
||||
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
||||
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
||||
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||
self.assertEqual(len(metrics["test"]), 1)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
||||
def test_run_eval_bart(model):
|
||||
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
|
||||
|
||||
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
|
||||
assert not output_file_name.exists()
|
||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||
_dump_articles(tmp, articles)
|
||||
testargs = ["run_eval.py", str(tmp), str(output_file_name), model] # TODO: test score_path
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_generate()
|
||||
assert Path(output_file_name).exists()
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
||||
)
|
||||
def test_finetune(model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
|
||||
tmp_dir = make_test_data_dir()
|
||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=model,
|
||||
tokenizer_name=None,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
output_dir=output_dir,
|
||||
do_predict=True,
|
||||
task=task,
|
||||
)
|
||||
assert "n_train" in args_d
|
||||
args = argparse.Namespace(**args_d)
|
||||
main(args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
|
||||
)
|
||||
def test_dataset(tok):
|
||||
tokenizer = AutoTokenizer.from_pretrained(tok)
|
||||
tmp_dir = make_test_data_dir()
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc_target = 4
|
||||
train_dataset = SummarizationDataset(
|
||||
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
||||
# show that articles were trimmed.
|
||||
assert batch["input_ids"].shape[1] == max_len_source
|
||||
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
|
||||
# show that targets were truncated
|
||||
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
|
||||
assert max_len_target > trunc_target # Truncated
|
||||
24
examples/seq2seq/train_distilbart_cnn.sh
Executable file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
export BS=32
|
||||
export GAS=1
|
||||
|
||||
python finetune.py \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--gpus 1 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--val_check_interval 0.25 \
|
||||
--n_val 500 \
|
||||
--num_train_epochs 2 \
|
||||
--freeze_encoder --freeze_embeds --data_dir $CNN_DIR \
|
||||
--max_target_length 142 --val_max_target_length=142 \
|
||||
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
|
||||
--data_dir $CNN_DIR \
|
||||
--model_name_or_path sshleifer/student_cnn_12_6 \
|
||||
--tokenizer_name facebook/bart-large \
|
||||
--output_dir distilbart-cnn-12-6 \
|
||||
$@
|
||||
|
||||
20
examples/seq2seq/train_distilbart_xsum.sh
Executable file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
export BS=16
|
||||
export GAS=2
|
||||
python distillation.py \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--fp16 \
|
||||
--val_check_interval 0.1 --n_val 1000 \
|
||||
--teacher facebook/bart-large-xsum --data_dir $XSUM_DIR \
|
||||
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
|
||||
--student_decoder_layers 6 --student_encoder_layers 12 \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--model_name_or_path IGNORED \
|
||||
--alpha_hid=3. --length_penalty=0.5 \
|
||||
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
|
||||
--tokenizer_name facebook/bart-large \
|
||||
--output_dir distilbart_xsum_12_6 \
|
||||
$@
|
||||
@@ -3,12 +3,13 @@ import json
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
|
||||
import git
|
||||
import numpy as np
|
||||
import torch
|
||||
from rouge_score import rouge_scorer, scoring
|
||||
from sacrebleu import corpus_bleu
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
from tqdm import tqdm
|
||||
@@ -41,7 +42,7 @@ def encode_file(
|
||||
examples = []
|
||||
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
|
||||
tokenized = tokenizer.batch_encode_plus(
|
||||
[text], # DONT ADD SPACES
|
||||
[text],
|
||||
max_length=max_length,
|
||||
pad_to_max_length=pad_to_max_length,
|
||||
add_prefix_space=True,
|
||||
@@ -54,11 +55,13 @@ def encode_file(
|
||||
return examples
|
||||
|
||||
|
||||
def lmap(f, x):
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
"""list(map(f, x))"""
|
||||
return list(map(f, x))
|
||||
|
||||
|
||||
T5_PREFIX = "summarize: " # HACK, fixme
|
||||
def calculate_bleu_score(output_lns, refs_lns) -> dict:
|
||||
return {"bleu": corpus_bleu(output_lns, [refs_lns]).score}
|
||||
|
||||
|
||||
def trim_batch(
|
||||
@@ -95,6 +98,8 @@ class SummarizationDataset(Dataset):
|
||||
tok_name=tok_name,
|
||||
)
|
||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||
if hasattr(tokenizer, "set_lang"):
|
||||
tokenizer.set_lang("ro_RO") # HACK: only applies to mbart
|
||||
self.target = encode_file(
|
||||
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
||||
)
|
||||
@@ -189,14 +194,20 @@ def flatten_list(summary_ids: List[List]):
|
||||
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
||||
|
||||
|
||||
def save_git_info(folder_path: str):
|
||||
"""
|
||||
Log commit info.
|
||||
"""
|
||||
def save_git_info(folder_path: str) -> None:
|
||||
"""Save git information to output_dir/git_log.json"""
|
||||
repo_infos = get_git_info()
|
||||
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
|
||||
|
||||
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
|
||||
json.dump(repo_infos, f, indent=4)
|
||||
|
||||
def save_json(content, path):
|
||||
with open(path, "w") as f:
|
||||
json.dump(content, f, indent=4)
|
||||
|
||||
|
||||
def load_json(path):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def get_git_info():
|
||||
@@ -1,70 +0,0 @@
|
||||
### Data
|
||||
|
||||
CNN/DailyMail data
|
||||
```bash
|
||||
cd examples/summarization
|
||||
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
|
||||
tar -xzvf cnn_dm.tgz
|
||||
export CNN_DIR=${PWD}/cnn_dm
|
||||
```
|
||||
|
||||
This should create a directory called `cnn_dm/` with files like `test.source`.
|
||||
To use your own data, copy that file format. Each article to be summarized is on its own line.
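
As a rough illustration of that layout, the sketch below writes a tiny dataset with one example per line. The `write_split` helper, the temporary directory, and the sample texts are assumptions made for this example; they are not part of the repository.

```python
import tempfile
from pathlib import Path


def write_split(data_dir: Path, split: str, sources: list, targets: list) -> None:
    """Write one split as <split>.source / <split>.target, one example per line."""
    if len(sources) != len(targets):
        raise ValueError("sources and targets must be the same length")
    for suffix, lines in ((".source", sources), (".target", targets)):
        # Collapse internal newlines so each example stays on a single line.
        cleaned = [" ".join(line.split()) for line in lines]
        (data_dir / f"{split}{suffix}").write_text("\n".join(cleaned) + "\n", encoding="utf-8")


# Hypothetical sample data, only to show the file layout.
data_dir = Path(tempfile.mkdtemp(prefix="my_data_"))
articles = ["Sam ate lunch today.\nIt was a sandwich.", "Sams lunch ingredients"]
summaries = ["A story about lunch.", "Avocado, celery, turkey, coffee"]
for split in ("train", "val", "test"):
    write_split(data_dir, split, articles, summaries)
```

Line `i` of a `.source` file is paired with line `i` of the matching `.target` file, so the two files of a split must have the same number of lines.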
|
||||
|
||||
XSUM Data:
|
||||
```bash
|
||||
cd examples/summarization
|
||||
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
|
||||
tar -xzvf xsum.tar.gz
|
||||
export XSUM_DIR=${PWD}/xsum
|
||||
```
|
||||
|
||||
|
||||
### Evaluation
|
||||
|
||||
To create summaries for each article in the dataset, run:
|
||||
```bash
|
||||
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
|
||||
```
|
||||
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
|
||||
|
||||
|
||||
### Training
|
||||
Run/modify `finetune.sh`
|
||||
|
||||
The following command should work on a 16GB GPU:
|
||||
```bash
|
||||
export me=`git config user.name`
|
||||
./finetune.sh \
|
||||
--data_dir $XSUM_DIR \
|
||||
--train_batch_size=1 \
|
||||
--eval_batch_size=1 \
|
||||
--output_dir="$me"_xsum_results \
|
||||
--num_train_epochs 1
|
||||
```
|
||||
|
||||
Tips:
|
||||
- 1 epoch at batch size 1 for bart-large takes 24 hours, requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
|
||||
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see below)
|
||||
- `fp16_opt_level=O1` (the default) works best.
|
||||
- If you are finetuning on your own dataset, start from `bart-large-cnn` if you want long summaries and `bart-large-xsum` if you want short summaries.
|
||||
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
|
||||
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
|
||||
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
|
||||
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
|
||||
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
||||
|
||||
### XSUM Shared Task
|
||||
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
|
||||
Here is an example command:
|
||||
```bash
|
||||
export me=`git config user.name`
|
||||
./finetune.sh \
|
||||
--data_dir $XSUM_DIR \
|
||||
--output_dir "$me"_xsum_frozen_embs \
|
||||
--logger wandb_shared \
|
||||
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
|
||||
--num_train_epochs 6
|
||||
```
|
||||
|
||||
Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_summarization/table?workspace=user-)
|
||||
@@ -1,267 +0,0 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import BartTokenizer
|
||||
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import main
|
||||
from .run_eval import generate_summaries, run_generate
|
||||
from .utils import SummarizationDataset, lmap, pickle_load
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
FP16_EVER = False
|
||||
CHEAP_ARGS = {
|
||||
"logger": "default",
|
||||
"num_workers": 2,
|
||||
"alpha_hid": 0,
|
||||
"freeze_embeds": True,
|
||||
"enc_only": False,
|
||||
"tgt_suffix": "",
|
||||
"resume_from_checkpoint": None,
|
||||
"sortish_sampler": True,
|
||||
"student_decoder_layers": 1,
|
||||
"val_check_interval": 1.0,
|
||||
"output_dir": "",
|
||||
"fp16": False,
|
||||
"no_teacher": False,
|
||||
"fp16_opt_level": "O1",
|
||||
"gpus": 1 if torch.cuda.is_available() else 0,
|
||||
"n_tpu_cores": 0,
|
||||
"max_grad_norm": 1.0,
|
||||
"do_train": True,
|
||||
"do_predict": True,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"server_ip": "",
|
||||
"server_port": "",
|
||||
"seed": 42,
|
||||
"model_type": "bart",
|
||||
"model_name_or_path": "sshleifer/bart-tiny-random",
|
||||
"config_name": "",
|
||||
"tokenizer_name": "facebook/bart-large",
|
||||
"cache_dir": "",
|
||||
"do_lower_case": False,
|
||||
"learning_rate": 3e-05,
|
||||
"weight_decay": 0.0,
|
||||
"adam_epsilon": 1e-08,
|
||||
"warmup_steps": 0,
|
||||
"num_train_epochs": 1,
|
||||
"train_batch_size": 2,
|
||||
"eval_batch_size": 2,
|
||||
"max_source_length": 12,
|
||||
"max_target_length": 12,
|
||||
"val_max_target_length": 12,
|
||||
"test_max_target_length": 12,
|
||||
"fast_dev_run": False,
|
||||
"no_cache": False,
|
||||
"n_train": -1,
|
||||
"n_val": -1,
|
||||
"n_test": -1,
|
||||
"student_encoder_layers": 1,
|
||||
"alpha_loss_encoder": 0.0,
|
||||
"freeze_encoder": False,
|
||||
"auto_scale_batch_size": False,
|
||||
}
|
||||
|
||||
|
||||
def _dump_articles(path: Path, articles: list):
|
||||
with path.open("w") as f:
|
||||
f.write("\n".join(articles))
|
||||
|
||||
|
||||
MSG = "T5 is broken at the moment"
|
||||
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
||||
|
||||
|
||||
def make_test_data_dir():
|
||||
tmp_dir = Path(tempfile.gettempdir())
|
||||
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
|
||||
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
for split in ["train", "val", "test"]:
|
||||
_dump_articles((tmp_dir / f"{split}.source"), articles)
|
||||
_dump_articles((tmp_dir / f"{split}.target"), summaries)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class TestSummarizationDistiller(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
return cls
|
||||
|
||||
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
|
||||
def test_bdc_multigpu(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
no_teacher=True,
|
||||
freeze_encoder=True,
|
||||
gpus=2,
|
||||
sortish_sampler=False,
|
||||
fp16_opt_level="O1",
|
||||
fp16=FP16_EVER,
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_t5_train(self):
|
||||
updates = dict(
|
||||
fp16=FP16_EVER,
|
||||
gpus=1 if torch.cuda.is_available() else 0,
|
||||
model_type="t5",
|
||||
model_name_or_path=T5_TINY,
|
||||
do_train=True,
|
||||
do_predict=True,
|
||||
tokenizer_name=T5_TINY,
|
||||
no_teacher=True,
|
||||
alpha_hid=2.0,
|
||||
)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_no_teacher(self):
|
||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True,)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_yes_teacher(self):
|
||||
updates = dict(student_encoder_layers=2, student_decoder_layers=1,)
|
||||
self._bart_distiller_cli(updates)
|
||||
|
||||
def test_bdc_checkpointing(self):
|
||||
updates = dict(
|
||||
student_encoder_layers=2,
|
||||
student_decoder_layers=1,
|
||||
num_train_epochs=4,
|
||||
val_check_interval=0.25,
|
||||
alpha_hid=2.0,
|
||||
)
|
||||
model = self._bart_distiller_cli(updates, check_contents=False)
|
||||
|
||||
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
||||
self.assertEqual(1, len(ckpts))
|
||||
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
self.assertEqual(len(transformer_ckpts), len(ckpts))
|
||||
new_transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
||||
self.assertEqual(len(new_transformer_ckpts), 1)
|
||||
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
|
||||
out_path = tempfile.mktemp()
|
||||
generate_summaries(examples, out_path, new_transformer_ckpts[0].parent)
|
||||
self.assertTrue(Path(out_path).exists())
|
||||
|
||||
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
||||
|
||||
def _bart_distiller_cli(self, updates, check_contents=True):
|
||||
default_updates = dict(
|
||||
train_batch_size=1,
|
||||
eval_batch_size=2,
|
||||
num_train_epochs=2,
|
||||
alpha_mlm=0.2,
|
||||
alpha_ce=0.8,
|
||||
do_predict=True,
|
||||
gpus=1 if torch.cuda.is_available() else 0,
|
||||
model_name_or_path="sshleifer/tinier_bart",
|
||||
teacher=CHEAP_ARGS["model_name_or_path"],
|
||||
val_check_interval=0.5,
|
||||
alpha_encoder_loss=0.4,
|
||||
)
|
||||
default_updates.update(updates)
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
tmp_dir = make_test_data_dir()
|
||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||
|
||||
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
||||
model = distill_main(argparse.Namespace(**args_d))
|
||||
if not check_contents:
|
||||
return model
|
||||
contents = os.listdir(output_dir)
|
||||
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
|
||||
contents = {os.path.basename(p) for p in contents}
|
||||
self.assertIn(ckpt_name, contents)
|
||||
self.assertIn("metrics.pkl", contents)
|
||||
self.assertIn("test_generations.txt", contents)
|
||||
self.assertIn("val_generations_00001.txt", contents)
|
||||
self.assertIn("val_results_00001.txt", contents)
|
||||
self.assertIn("test_results.txt", contents)
|
||||
|
||||
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
|
||||
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
||||
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
||||
self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here
|
||||
return model
|
||||
|
||||
|
||||
class TestBartExamples(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
||||
return cls
|
||||
|
||||
def test_bart_cnn_cli(self):
|
||||
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
|
||||
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
|
||||
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
||||
_dump_articles(tmp, articles)
|
||||
testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_generate()
|
||||
self.assertTrue(Path(output_file_name).exists())
|
||||
os.remove(Path(output_file_name))
|
||||
|
||||
def test_t5_run_sum_cli(self):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
|
||||
tmp_dir = make_test_data_dir()
|
||||
output_dir = tempfile.mkdtemp(prefix="output_")
|
||||
args_d.update(
|
||||
data_dir=tmp_dir,
|
||||
model_name_or_path=T5_TINY,
|
||||
tokenizer_name=None, # T5_TINY,
|
||||
train_batch_size=2,
|
||||
eval_batch_size=2,
|
||||
gpus=0,
|
||||
output_dir=output_dir,
|
||||
do_predict=True,
|
||||
)
|
||||
assert "n_train" in args_d
|
||||
args = argparse.Namespace(**args_d)
|
||||
main(args)
|
||||
|
||||
def test_bart_summarization_dataset(self):
|
||||
tmp_dir = Path(tempfile.gettempdir())
|
||||
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
|
||||
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
||||
_dump_articles((tmp_dir / "train.source"), articles)
|
||||
_dump_articles((tmp_dir / "train.target"), summaries)
|
||||
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
|
||||
max_len_source = max(len(tokenizer.encode(a)) for a in articles)
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
|
||||
trunc_target = 4
|
||||
train_dataset = SummarizationDataset(
|
||||
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape)
|
||||
# show that articles were trimmed.
|
||||
self.assertEqual(batch["input_ids"].shape[1], max_len_source)
|
||||
self.assertGreater(20, batch["input_ids"].shape[1]) # trimmed significantly
|
||||
|
||||
# show that targets were truncated
|
||||
self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target) # Truncated
|
||||
self.assertGreater(max_len_target, trunc_target) # Truncated
|
||||
|
||||
|
||||
def list_to_text_file(lst, path):
|
||||
dest = Path(path)
|
||||
dest.open("w+").writelines(lst)
|
||||
@@ -1,51 +0,0 @@
|
||||
***This script evaluates the multitask pre-trained checkpoint for ``t5-base`` (see paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the English to German WMT dataset. Please note that the results in the paper were attained using a model fine-tuned on translation, so that results will be worse here by approx. 1.5 BLEU points***
|
||||
|
||||
### Intro
|
||||
|
||||
This example shows how T5 (here the official [paper](https://arxiv.org/abs/1910.10683)) can be
|
||||
evaluated on the WMT English-German dataset.
|
||||
|
||||
### Get the WMT Data
|
||||
|
||||
To be able to reproduce the authors' results on WMT English to German, you first need to download
|
||||
the WMT14 en-de news datasets.
|
||||
Go to Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2014.en" and "newstest2014.de" under WMT'14 English-German data, or download the dataset directly via:
|
||||
|
||||
```bash
|
||||
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en > newstest2014.en
|
||||
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de > newstest2014.de
|
||||
```
|
||||
|
||||
You should have 2737 sentences in each file. You can verify this by running:
|
||||
|
||||
```bash
|
||||
wc -l newstest2014.en # should give 2737
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
Let's check the longest and shortest sentence in our file to find reasonable decoding hyperparameters:
|
||||
|
||||
Get the longest and shortest sentence:
|
||||
|
||||
```bash
|
||||
awk '{print NF}' newstest2014.en | sort -n | head -1 # shortest sentence has 2 words
|
||||
awk '{print NF}' newstest2014.en | sort -n | tail -1 # longest sentence has 91 words
|
||||
```
|
||||
|
||||
We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0.
|
||||
We decode with beam search `num_beams=4` as proposed in the paper. Also as is common in beam search we set `early_stopping=True` and `length_penalty=2.0`.
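
For readers who prefer Python to `awk`, the following is a minimal sketch of the same word-count check and of the "~3 times the longest sentence" rule. The `sentence_length_range` helper and the sample sentences are assumptions for illustration; in practice the lines would come from `newstest2014.en`.

```python
def sentence_length_range(lines):
    """Return (shortest, longest) whitespace-separated word counts, like awk's NF."""
    counts = [len(line.split()) for line in lines]
    return min(counts), max(counts)


# Hypothetical stand-in for the lines of newstest2014.en.
sample = [
    "Good morning",
    "This is a slightly longer example sentence",
    "Hello there world",
]
shortest, longest = sentence_length_range(sample)
max_length = 3 * longest  # roughly three times the longest sentence
print(shortest, longest, max_length)
```

Note that these are word counts, while `max_length` is measured in tokens, so the factor of three is only a loose upper bound.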
|
||||
|
||||
To create a translation for each sentence in the dataset and get a final BLEU score, run:
|
||||
```bash
|
||||
python evaluate_wmt.py t5-base <path_to_newstest2014.en> newstest2014_de_translations.txt <path_to_newstest2014.de> newstest2014_en_de_bleu.txt
|
||||
```
|
||||
The default batch size, 16, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
|
||||
|
||||
### Where is the code?
|
||||
The core model is in `src/transformers/modeling_t5.py`. This directory only contains examples.
|
||||
|
||||
### BLEU Scores
|
||||
|
||||
The BLEU score is calculated using [sacrebleu](https://github.com/mjpost/sacreBLEU) by mjpost.
|
||||
To get the BLEU score we used
|
||||
@@ -1,103 +0,0 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from sacrebleu import corpus_bleu
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
||||
def generate_translations(lns, output_file_path, model_size, batch_size, device):
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_size)
|
||||
model.to(device)
|
||||
|
||||
tokenizer = T5Tokenizer.from_pretrained(model_size)
|
||||
|
||||
# update config with translation specific params
|
||||
task_specific_params = model.config.task_specific_params
|
||||
if task_specific_params is not None:
|
||||
model.config.update(task_specific_params.get("translation_en_to_de", {}))
|
||||
|
||||
with Path(output_file_path).open("w") as output_file:
|
||||
for batch in tqdm(list(chunks(lns, batch_size))):
|
||||
batch = [model.config.prefix + text for text in batch]
|
||||
|
||||
dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)
|
||||
|
||||
input_ids = dct["input_ids"].to(device)
|
||||
attention_mask = dct["attention_mask"].to(device)
|
||||
|
||||
translations = model.generate(input_ids=input_ids, attention_mask=attention_mask)
|
||||
dec = [
|
||||
tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations
|
||||
]
|
||||
|
||||
for hypothesis in dec:
|
||||
output_file.write(hypothesis + "\n")
|
||||
|
||||
|
||||
def calculate_bleu_score(output_lns, refs_lns, score_path):
|
||||
bleu = corpus_bleu(output_lns, [refs_lns])
|
||||
result = "BLEU score: {}".format(bleu.score)
|
||||
with Path(score_path).open("w") as score_file:
|
||||
score_file.write(result)
|
||||
|
||||
|
||||
def run_generate():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"model_size",
|
||||
type=str,
|
||||
help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
|
||||
default="t5-base",
|
||||
)
|
||||
parser.add_argument(
|
||||
"input_path", type=str, help="like wmt/newstest2014.en",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_path", type=str, help="where to save translation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"reference_path", type=str, help="like wmt/newstest2014.de",
|
||||
)
|
||||
parser.add_argument(
|
||||
"score_path", type=str, help="where to save the bleu score",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=16, required=False, help="batch size: how many to summarize at a time",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
|
||||
dash_pattern = (" ##AT##-##AT## ", "-")
|
||||
|
||||
# Read input lines into python
|
||||
with open(args.input_path, "r") as input_file:
|
||||
input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in input_file.readlines()]
|
||||
|
||||
generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device)
|
||||
|
||||
# Read generated lines into python
|
||||
with open(args.output_path, "r") as output_file:
|
||||
output_lns = [x.strip() for x in output_file.readlines()]
|
||||
|
||||
# Read reference lines into python
|
||||
with open(args.reference_path, "r") as reference_file:
|
||||
refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in reference_file.readlines()]
|
||||
|
||||
calculate_bleu_score(output_lns, refs_lns, args.score_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_generate()
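
The `--no_cuda` option in the script above is meant as an on/off switch. As a minimal sketch using only the standard library, such a switch can be declared with `action="store_true"`; `build_parser` is a hypothetical helper introduced here for illustration and is not part of the script. The last line shows why `type=bool` is a poor fit for flags: argparse passes the raw string to `bool()`, and any non-empty string is truthy.

```python
import argparse


def build_parser():
    """Build a small parser with one boolean switch (illustrative helper, not from the script)."""
    parser = argparse.ArgumentParser()
    # store_true: the value is False unless `--no_cuda` appears on the command line.
    parser.add_argument("--no_cuda", action="store_true", help="Force execution on CPU.")
    return parser


parser = build_parser()
print(parser.parse_args([]).no_cuda)             # False
print(parser.parse_args(["--no_cuda"]).no_cuda)  # True

# For comparison: type=bool would call bool() on the raw argument string,
# so even the text "False" would be converted to True.
print(bool("False"))  # True
```

With this declaration, `--no_cuda` takes no value on the command line; its presence alone sets it.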

@@ -1,50 +0,0 @@
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_wmt import run_generate


text = ["When Liana Barrientos was 23 years old, she got married in Westchester County."]
translation = ["Als Liana Barrientos 23 Jahre alt war, heiratete sie in Westchester County."]

output_file_name = "output_t5_trans.txt"
score_file_name = "score_t5_trans.txt"

logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()


class TestT5Examples(unittest.TestCase):
    def test_t5_cli(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_source = Path(tempfile.gettempdir()) / "utest_generations_t5_trans.hypo"
        with tmp_source.open("w") as f:
            f.write("\n".join(text))

        tmp_target = Path(tempfile.gettempdir()) / "utest_generations_t5_trans.target"
        with tmp_target.open("w") as f:
            f.write("\n".join(translation))

        output_file_name = Path(tempfile.gettempdir()) / "utest_output_trans.hypo"
        score_file_name = Path(tempfile.gettempdir()) / "utest_score.hypo"

        testargs = [
            "evaluate_wmt.py",
            "patrickvonplaten/t5-tiny-random",
            str(tmp_source),
            str(output_file_name),
            str(tmp_target),
            str(score_file_name),
        ]

        with patch.object(sys, "argv", testargs):
            run_generate()
            self.assertTrue(Path(output_file_name).exists())
            self.assertTrue(Path(score_file_name).exists())
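
The test above drives the command-line entry point by temporarily replacing `sys.argv`. A minimal, self-contained sketch of that technique follows; `main`, the file names, and the argument values are hypothetical stand-ins chosen for illustration, not part of the test suite.

```python
import argparse
import sys
from unittest.mock import patch


def main():
    """Hypothetical CLI entry point that reads its arguments from sys.argv."""
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str)
    parser.add_argument("--batch_size", type=int, default=16)
    return parser.parse_args()


# sys.argv[0] is the program name; the remaining items are the arguments.
fake_argv = ["main.py", "data/test.source", "--batch_size", "4"]

# patch.object swaps sys.argv only inside the with-block and restores it afterwards.
with patch.object(sys, "argv", fake_argv):
    args = main()

print(args.input_path, args.batch_size)  # data/test.source 4
```

Because the patch is undone when the block exits, other tests in the same process see the original `sys.argv`.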