examples/seq2seq supports translation (#5202)

Sam Shleifer
2020-06-24 23:58:11 -04:00
committed by GitHub
parent d12ceb48ba
commit 40457bcebb
32 changed files with 626 additions and 636 deletions

examples/seq2seq/README.md Normal file

@@ -0,0 +1,169 @@
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
Summarization support is more mature than translation support.
Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
For `bertabs` instructions, see `bertabs/README.md`.
### Data
CNN/DailyMail data
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```
This should make a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file format. Each article to be summarized is on its own line.
XSUM Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
WMT16 English-Romanian Translation Data:
```bash
cd examples/seq2seq
wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export ENRO_DIR=${PWD}/wmt_en_ro
```
If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
The `.source` files are the input, the `.target` files are the desired output.
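Before launching a long run on your own data, it can help to check the directory layout first. The snippet below is a minimal sketch, not part of this repository: `check_data_dir` is a hypothetical helper, and it assumes that line *i* of a `.source` file pairs with line *i* of the matching `.target` file.

```python
from pathlib import Path

SPLITS = ("train", "val", "test")


def check_data_dir(data_dir):
    """Return a list of problems found in a seq2seq data directory (empty if none).

    Hypothetical helper: assumes one example per line and line-aligned
    .source/.target files.
    """
    data_dir = Path(data_dir)
    problems = []
    for split in SPLITS:
        src = data_dir / f"{split}.source"
        tgt = data_dir / f"{split}.target"
        missing = [p.name for p in (src, tgt) if not p.exists()]
        if missing:
            problems.append(f"missing: {', '.join(missing)}")
            continue
        # One example per line, so both files should have the same line count.
        n_src = len(src.read_text(encoding="utf-8").splitlines())
        n_tgt = len(tgt.read_text(encoding="utf-8").splitlines())
        if n_src != n_tgt:
            problems.append(f"{split}: {n_src} source lines vs {n_tgt} target lines")
    return problems
```

An empty list means all six files exist and each split has matching line counts.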
### Evaluation
To create summaries for each article in a dataset, run:
```bash
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
```
The batch size is set with `--bs`. A batch size of 4 fits in 16GB of GPU memory, but may need to be adjusted to fit your system.
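As a rough illustration of what the batch size controls, `run_eval.py` splits the input lines into fixed-size chunks and generates one chunk at a time. The sketch below is a standalone approximation of its `chunks` helper. Only the slicing line is visible in this commit, so the surrounding loop is an assumption.

```python
def chunks(lst, n):
    """Yield successive n-sized chunks from lst (the last chunk may be shorter)."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


articles = [f"article {i}" for i in range(10)]
batches = list(chunks(articles, 4))
print([len(b) for b in batches])  # [4, 4, 2]
```

Peak memory is driven by the largest chunk, so lowering the batch size is the first thing to try if generation runs out of GPU memory.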
### Summarization Finetuning
Run/modify `finetune.sh`
The following command should work on a 16GB GPU:
```bash
./finetune.sh \
--data_dir $XSUM_DIR \
--train_batch_size=1 \
--eval_batch_size=1 \
--output_dir=xsum_results \
--num_train_epochs 1 \
--model_name_or_path facebook/bart-large
```
*Note*: The following tips mostly apply to summarization finetuning.
Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- Try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default) works best.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`.
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
#### Finetuning Outputs
As you train, `output_dir` will be filled with files that look roughly like this (comments are mine).
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
```bash
output_dir
├── best_tfmr # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
│   ├── config.json
│   ├── merges.txt
│   ├── pytorch_model.bin
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── git_log.json # repo, branch, and commit hash
├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score.
├── metrics.json # new validation metrics will continually be appended to this
├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
│   ├── config.json
│   └── pytorch_model.bin
├── test_generations.txt
# ^^ are the summaries or translations produced by your best checkpoint on the test data. Populated when training is done
├── test_results.txt # a convenience file with the test set metrics. This data is also in metrics.json['test']
├── hparams.pkl # the command line args passed after some light preprocessing. Should be saved fairly quickly.
```
After training, you can recover the best checkpoint by running
```python
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
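To see how validation scores evolved during training, you can read `metrics.json` directly. This is a minimal sketch, assuming the file holds a `"val"` list whose entries carry keys such as `val_avg_rouge2` and `step_count` (which is how `save_metrics` appears to write it in this commit). `best_val_entry` is a hypothetical helper, not something the repository provides.

```python
import json
from pathlib import Path


def best_val_entry(output_dir, metric="rouge2"):
    """Return the validation entry with the highest val_avg_<metric>.

    Hypothetical helper: assumes metrics.json looks like
    {"val": [{"val_avg_rouge2": ..., "step_count": ...}, ...], "test": [...]}.
    """
    metrics = json.loads((Path(output_dir) / "metrics.json").read_text())
    key = f"val_avg_{metric}"
    return max(metrics["val"], key=lambda entry: entry[key])
```

For a translation run, pass `metric="bleu"` instead, since that is the metric name the translation module uses.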
### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!
```bash
./finetune.sh \
--data_dir $XSUM_DIR \
--output_dir xsum_frozen_embs \
--model_name_or_path facebook/bart-large \
--logger wandb_shared \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6 \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
```
Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
### Distilbart
#### No Teacher Distillation
To run the simpler distilbart-cnn style distillation all you need is data, a GPU, and a properly initialized student.
You don't even need `distillation.py`.
Some [un-finetuned students](https://huggingface.co/models?search=sshleifer%2Fstudent) are available for replication purposes.
They are initialized by copying layers from the associated `bart-large-{cnn|xsum}` teacher using `--init_strategy alternate`. (You can read about that in `initialization_utils.py`)
The command that produced `sshleifer/distilbart-cnn-12-6` is
```bash
./train_distilbart_cnn.sh
```
runtime: 6H on NVIDIA RTX 24GB GPU
*Note*: You can get the same simple distillation logic by using `./run_distiller.sh --no_teacher` followed by identical arguments as the ones in `train_distilbart_cnn.sh`.
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
because you will have the same hyperparameters logged in every run.
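The layer-copying idea mentioned above can be sketched as a lookup from student depth to teacher layer indices. The table entries below are taken from the `get_layers_to_copy` helper touched by this commit. Whether this is exactly what `--init_strategy alternate` does is an assumption, and only the entries that are unambiguous in the diff are reproduced here.

```python
def get_layers_to_copy(n_to_get, tot):
    """Pick which teacher layers a student with n_to_get layers should copy."""
    all_layers = list(range(tot))
    if tot == 12:
        # Spread the copied layers across the 12-layer teacher,
        # always keeping the first and last layer.
        layers_to_copy = {
            3: [0, 6, 11],
            4: [0, 4, 8, 11],
            6: [0, 2, 4, 7, 9, 11],
            9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
            12: all_layers,
        }
        return layers_to_copy[n_to_get]
    # Fallback for other teacher depths: take the first n_to_get layers.
    return all_layers[:n_to_get]


print(get_layers_to_copy(6, 12))  # [0, 2, 4, 7, 9, 11]
print(get_layers_to_copy(3, 8))   # [0, 1, 2]
```

So a student with a 6-layer decoder copies roughly every other layer from the 12-layer teacher, which is what the `12-6` suffix in the checkpoint names refers to.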
#### With a teacher
*Note*: only BART variants are supported.
In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
This is how `sshleifer/distilbart-xsum*` checkpoints were produced.
The command that produced `sshleifer/distilbart-xsum-12-6` is:
```bash
./train_distilbart_xsum.sh
```
runtime: 13H on V-100 16GB GPU.
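One way to picture the teacher-based objective is as a weighted sum of the individual loss terms. The sketch below is an assumption about how the terms are combined, not a copy of `distillation.py`. The weight names mirror the `--alpha_*` command line arguments and the dictionary keys mirror `loss_names`, but `blend_losses` itself is hypothetical and plain floats stand in for tensors.

```python
def blend_losses(losses, alpha_ce, alpha_mlm, alpha_encoder_loss=0.0, alpha_hid=0.0):
    """Combine distillation loss terms into one scalar (assumed weighted sum).

    losses: dict with keys "ce_loss", "mlm_loss", "enc_mse_loss",
            "hid_loss_enc" and "hid_loss_dec".
    """
    return (
        alpha_ce * losses["ce_loss"]  # match the teacher's logits
        + alpha_mlm * losses["mlm_loss"]  # ordinary loss against the labels
        + alpha_encoder_loss * losses["enc_mse_loss"]  # match encoder outputs
        + alpha_hid * (losses["hid_loss_enc"] + losses["hid_loss_dec"])  # match hidden states
    )


example = {
    "ce_loss": 1.0,
    "mlm_loss": 2.0,
    "enc_mse_loss": 4.0,
    "hid_loss_enc": 0.5,
    "hid_loss_dec": 0.5,
}
print(blend_losses(example, alpha_ce=0.5, alpha_mlm=0.25, alpha_hid=2.0))  # 3.0
```

With `alpha_encoder_loss` and `alpha_hid` left at 0.0, only the logit-matching and label terms contribute.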
### Contributing
- Follow the standard contributing guidelines and code of conduct.
- Add tests to `test_seq2seq_examples.py`.
- To run only the seq2seq tests, you must be in the root of the repository and run:
```bash
pytest examples/seq2seq/
```


@@ -12,7 +12,7 @@ The model is loaded with the pre-trained weights for the abstractive summarizati
git clone https://github.com/huggingface/transformers && cd transformers
pip install .
pip install nltk py-rouge
cd examples/summarization
cd examples/seq2seq/bertabs
```
## Reproduce the authors' ROUGE score


@@ -32,9 +32,12 @@ class Seq2SeqLoggingCallback(pl.Callback):
results_file = od / "test_results.txt"
generations_file = od / "test_generations.txt"
else:
results_file = od / f"{type_path}_results_{trainer.global_step:05d}.txt"
generations_file = od / f"{type_path}_generations_{trainer.global_step:05d}.txt"
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
# If people want this it will be easy enough to add back.
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
results_file.parent.mkdir(exist_ok=True)
generations_file.parent.mkdir(exist_ok=True)
with open(results_file, "a+") as writer:
for key in sorted(metrics):
if key in ["log", "progress_bar", "preds"]:
@@ -63,20 +66,25 @@ class Seq2SeqLoggingCallback(pl.Callback):
# mp stands for million parameters
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
@rank_zero_only
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
return self._write_logs(trainer, pl_module, "val")
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
return self._write_logs(trainer, pl_module, "test")
def get_rouge2_checkpoint_callback(output_dir):
def get_checkpoint_callback(output_dir, metric):
"""Saves the best model by validation ROUGE2 score."""
if metric == "rouge2":
exp = "{val_avg_rouge2:.4f}-{step_count}"
elif metric == "bleu":
exp = "{val_avg_bleu:.4f}-{step_count}"
else:
raise NotImplementedError(
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(output_dir, "{val_avg_rouge2:.4f}-{step_count}"),
monitor="val_rouge",
filepath=os.path.join(output_dir, exp),
monitor=f"val_{metric}",
mode="max",
save_top_k=1,
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.


@@ -39,13 +39,12 @@ except ImportError:
)
class SummarizationDistiller(SummarizationModule):
class BartSummarizationDistiller(SummarizationModule):
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
def __init__(self, hparams):
assert Path(hparams.data_dir).exists()
d_layers_to_copy, student, student_cfg, teacher = self.pre_init(hparams)
student, student_cfg, teacher = self.pre_init(hparams)
super().__init__(hparams, model=student, config=student_cfg)
self.teacher = teacher
@@ -73,12 +72,15 @@ class SummarizationDistiller(SummarizationModule):
del self.teacher.model.encoder
def pre_init(self, hparams):
# Dump empty student model at a path, then call from_pretrained on it
self.output_dir = Path(hparams.output_dir)
self.output_dir.mkdir(exist_ok=True)
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
student_updates = {
"decoder_layers": hparams.student_decoder_layers,
"encoder_layers": hparams.student_encoder_layers,
}
if hparams.length_penalty != -1:
student_updates["length_penalty"] = hparams.length_penalty
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
hparams.d_layer_to_copy = d_layers_to_copy
@@ -89,9 +91,13 @@ class SummarizationDistiller(SummarizationModule):
student_cfg = BartConfig(**kw)
student = BartForConditionalGeneration(student_cfg)
student, _ = init_student(student, teacher)
save_dir = self.output_dir.joinpath("student")
save_dir.mkdir(exist_ok=True)
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
Path(hparams.output_dir).mkdir(exist_ok=True)
return d_layers_to_copy, student, student_cfg, teacher
student.save_pretrained(save_dir)
hparams.model_name_or_path = str(save_dir)
return student, student_cfg, teacher
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
if teacher.config.model_type == "t5":
@@ -154,7 +160,6 @@ class SummarizationDistiller(SummarizationModule):
def configure_optimizers(self):
"Prepare optimizer and schedule (linear warmup and decay)"
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
@@ -180,18 +185,11 @@ class SummarizationDistiller(SummarizationModule):
# parser.add_argument("--alpha_cos", default=0.0, type=float)
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
parser.add_argument(
"--student_decoder_layers", default=12, type=int, required=False,
)
parser.add_argument(
"--student_encoder_layers", default=12, type=int, required=False,
)
parser.add_argument(
"--no_teacher", action="store_true", default=False,
)
parser.add_argument( # TODO: remove
"--enc_only", action="store_true", default=False,
)
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
parser.add_argument("--no_teacher", action="store_true", default=False)
parser.add_argument("--length_penalty", type=float, default=-1)
return parser
def _step(self, batch):
@@ -269,12 +267,14 @@ class SummarizationDistiller(SummarizationModule):
return sum(hidden_losses)
class T5SummarizationDistiller(SummarizationDistiller):
class T5SummarizationDistiller(BartSummarizationDistiller):
def pre_init(self, hparams):
raise NotImplementedError("T5 Distillation does not work yet")
self.output_dir = Path(hparams.output_dir)
self.output_dir.mkdir(exist_ok=True)
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
n_layer = hparams.student_decoder_layers
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this constraint so that we can do 12-6.
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
student_updates = {"num_layers": n_layer}
@@ -291,8 +291,13 @@ class T5SummarizationDistiller(SummarizationDistiller):
Path(hparams.output_dir).mkdir(exist_ok=True)
task_specific_params = student.config.task_specific_params
if task_specific_params is not None:
student.config.update(task_specific_params.get("summarization", {}))
return d_layers_to_copy, student, student_cfg, teacher
student.config.update(task_specific_params.get("summarization", {})) # TODO: dont hardcode
save_dir = self.output_dir.joinpath("student")
save_dir.mkdir(exist_ok=True)
student.save_pretrained(save_dir)
hparams.model_name_or_path = str(save_dir)
return student, student_cfg, teacher
def freeze_embeds(self):
freeze_params(self.model.shared)
@@ -386,7 +391,7 @@ def create_module(args):
elif args.enc_only:
raise ValueError("Deleted that")
else:
module_cls = SummarizationDistiller
module_cls = BartSummarizationDistiller
args.setup_cls: str = module_cls.__name__
model = module_cls(args)
return model
@@ -418,18 +423,18 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
def get_layers_to_copy(n_to_get, tot):
all_layers = list(range(tot))
if tot == 12: # Alternating for special cases
layers_to_copy = { # maps # layers in student -> which teacher layers to copy
6: [0, 2, 4, 7, 9, 11],
1: [11],
layers_to_copy = { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 6],
3: [0, 6, 11],
2: [0, 11],
4: [0, 4, 8, 11],
6: [0, 2, 4, 7, 9, 11],
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
12: all_layers,
}
return layers_to_copy[n_to_get]
else:
return all_layers[:n_to_get]
return all_layers[:n_to_get] # TODO: better version on theseus-bart branch
def distill_main(args):
@@ -443,7 +448,7 @@ def distill_main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
distill_main(args)


@@ -3,6 +3,7 @@ import glob
import logging
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
@@ -23,12 +24,14 @@ try:
flatten_list,
pickle_save,
save_git_info,
save_json,
freeze_params,
calculate_rouge,
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
)
from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
except ImportError:
from utils import (
use_task_specific_params,
@@ -37,12 +40,14 @@ except ImportError:
flatten_list,
pickle_save,
save_git_info,
save_json,
freeze_params,
calculate_rouge,
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
)
from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
logger = logging.getLogger(__name__)
@@ -50,15 +55,18 @@ logger = logging.getLogger(__name__)
class SummarizationModule(BaseTransformer):
mode = "summarization"
loss_names = ["loss"]
metric_names = ROUGE_KEYS
val_metric = "rouge2"
def __init__(self, hparams, **kwargs):
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
use_task_specific_params(self.model, "summarization")
save_git_info(self.hparams.output_dir)
self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
pickle_save(self.hparams, self.hparams_save_path)
self.step_count = 0
self.metrics = {"train": [], "val": [], "test": []}
self.metrics = defaultdict(list)
self.dataset_kwargs: dict = dict(
data_dir=self.hparams.data_dir,
@@ -89,12 +97,12 @@ class SummarizationModule(BaseTransformer):
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
if self.model.config.model_type == "bart":
try:
freeze_params(self.model.model.shared)
for d in [self.model.model.encoder, self.model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
else:
except AttributeError:
freeze_params(self.model.shared)
for d in [self.model.encoder, self.model.decoder]:
freeze_params(d.embed_tokens)
@@ -130,19 +138,22 @@ class SummarizationModule(BaseTransformer):
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]}
rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
rouges.update({k: v.item() for k, v in losses.items()})
losses.update(rouges)
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
metrics["step_count"] = self.step_count
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
preds = flatten_list([x["preds"] for x in outputs])
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}
return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
def save_metrics(self, metrics, prefix) -> None:
self.metrics[prefix].append(metrics)
pickle_save(self.metrics, self.metrics_save_path)
def save_metrics(self, latest_metrics, type_path) -> None:
self.metrics[type_path].append(latest_metrics)
save_json(self.metrics, self.metrics_save_path)
def calc_generative_metrics(self, preds, target) -> Dict:
return calculate_rouge(preds, target)
def _generative_step(self, batch: dict) -> dict:
pad_token_id = self.tokenizer.pad_token_id
@@ -154,7 +165,7 @@ class SummarizationModule(BaseTransformer):
target = self.ids_to_clean_text(y)
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
rouge: Dict = calculate_rouge(preds, target)
rouge: Dict = self.calc_generative_metrics(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
return base_metrics
@@ -259,15 +270,33 @@ class SummarizationModule(BaseTransformer):
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="summarization or translation"
)
return parser
class TranslationModule(SummarizationModule):
mode = "translation"
loss_names = ["loss"]
metric_names = ["bleu"]
val_metric = "bleu"
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu_score(preds, target)
def main(args, model=None) -> SummarizationModule:
Path(args.output_dir).mkdir(exist_ok=True)
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if model is None:
model: BaseTransformer = SummarizationModule(args)
if args.task == "summarization":
model: SummarizationModule = SummarizationModule(args)
else:
model: SummarizationModule = TranslationModule(args)
dataset = Path(args.data_dir).name
if (
args.logger == "default"
or args.fast_dev_run
@@ -278,17 +307,17 @@ def main(args, model=None) -> SummarizationModule:
elif args.logger == "wandb":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name)
logger = WandbLogger(name=model.output_dir.name, project=dataset)
elif args.logger == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
# TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
logger=logger,
# TODO: early stopping callback seems messed up
)


@@ -1,13 +1,8 @@
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
# --model_name_or_path=t5-base for t5
# the proper usage is documented in the README
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
python finetune.py \
--model_name_or_path=facebook/bart-large \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
@@ -16,5 +11,4 @@ python finetune.py \
--n_val 1000 \
--val_check_interval 0.1 \
--sortish_sampler \
--max_target_length=56 \
$@


@@ -1,5 +1,3 @@
#CNN_DIR = /home/shleifer/transformers_fork/examples/summarization/bart/cnn_dm
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"


@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
try:
from .finetune import calculate_rouge, use_task_specific_params
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
except ImportError:
from finetune import calculate_rouge, use_task_specific_params
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -22,8 +22,14 @@ def chunks(lst, n):
yield lst[i : i + n]
def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
def generate_summaries_or_translations(
examples: list,
out_file: str,
model_name: str,
batch_size: int = 8,
device: str = DEFAULT_DEVICE,
fp16=False,
**gen_kwargs,
) -> None:
fout = Path(out_file).open("w", encoding="utf-8")
model_name = str(model_name)
@@ -39,11 +45,10 @@ def generate_summaries(
for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to(
device
)
summaries = model.generate(**dct)
batch = tokenizer.batch_encode_plus(
batch, max_length=1024, return_tensors="pt", truncation=True, pad_to_max_length=True
).to(device)
summaries = model.generate(**batch, **gen_kwargs)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
@@ -57,22 +62,26 @@ def run_generate():
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge")
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
generate_summaries(
generate_summaries_or_translations(
examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16
)
if args.score_path is not None:
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
scores = {}
if args.reference_path is not None:
score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
rouge: dict = calculate_rouge(output_lns, reference_lns)
json.dump(rouge, open("score_path", "w+"))
scores: dict = score_fn(output_lns, reference_lns)
if args.score_path is not None:
json.dump(scores, open(args.score_path, "w+"))
return scores
if __name__ == "__main__":


@@ -0,0 +1,252 @@
import argparse
import logging
import os
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import pytest
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .run_eval import generate_summaries_or_translations, run_generate
from .utils import SummarizationDataset, lmap, load_json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"logger": "default",
"length_penalty": 0.5,
"cache_dir": "",
"task": "summarization",
"num_workers": 2,
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
"tgt_suffix": "",
"resume_from_checkpoint": None,
"sortish_sampler": True,
"student_decoder_layers": 1,
"val_check_interval": 1.0,
"output_dir": "",
"fp16": CUDA_AVAILABLE,
"no_teacher": False,
"fp16_opt_level": "O1",
"gpus": 1 if CUDA_AVAILABLE else 0,
"n_tpu_cores": 0,
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"gradient_accumulation_steps": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
"model_name_or_path": "sshleifer/bart-tiny-random",
"config_name": "",
"tokenizer_name": "facebook/bart-large",
"do_lower_case": False,
"learning_rate": 0.3,
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"num_train_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
"max_target_length": 12,
"val_max_target_length": 12,
"test_max_target_length": 12,
"fast_dev_run": False,
"no_cache": False,
"n_train": -1,
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"alpha_loss_encoder": 0.0,
"freeze_encoder": False,
"auto_scale_batch_size": False,
}
def _dump_articles(path: Path, articles: list):
with path.open("w") as f:
f.write("\n".join(articles))
ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
def make_test_data_dir(**kwargs):
tmp_dir = Path(tempfile.mkdtemp(**kwargs))
for split in ["train", "val", "test"]:
_dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
_dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
return tmp_dir
class TestSummarizationDistiller(unittest.TestCase):
@classmethod
def setUpClass(cls):
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
def test_multigpu(self):
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
self._test_distiller_cli(updates)
def test_distill_no_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
self._test_distiller_cli(updates)
def test_distill_checkpointing_with_teacher(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
)
model = self._test_distiller_cli(updates, check_contents=False)
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
self.assertEqual(len(transformer_ckpts), 2)
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
out_path = tempfile.mktemp()
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
self.assertTrue(Path(out_path).exists())
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
@unittest.skip("T5 distillation is broken at the moment")
def test_distill_t5(self):
updates = dict(
student_encoder_layers=1,
student_decoder_layers=1,
alpha_hid=2.0,
teacher=T5_TINY,
model_name_or_path=T5_TINY,
tokenizer_name=T5_TINY,
)
self._test_distiller_cli(updates)
def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
alpha_encoder_loss=0.4,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
model = distill_main(argparse.Namespace(**args_d))
if not check_contents:
return model
contents = os.listdir(output_dir)
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
contents = {os.path.basename(p) for p in contents}
self.assertIn(ckpt_name, contents)
self.assertIn("test_generations.txt", contents)
self.assertIn("test_results.txt", contents)
metrics = load_json(model.metrics_save_path)
last_step_stats = metrics["val"][-1]
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["test"]), 1)
return model
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval_bart(model):
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
assert not output_file_name.exists()
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
_dump_articles(tmp, articles)
testargs = ["run_eval.py", str(tmp), str(output_file_name), model] # TODO: test score_path
with patch.object(sys, "argv", testargs):
run_generate()
assert Path(output_file_name).exists()
os.remove(Path(output_file_name))
@pytest.mark.parametrize(
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)
def test_finetune(model):
args_d: dict = CHEAP_ARGS.copy()
task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_name_or_path=model,
tokenizer_name=None,
train_batch_size=2,
eval_batch_size=2,
output_dir=output_dir,
do_predict=True,
task=task,
)
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
main(args)
@pytest.mark.parametrize(
["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)]
)
def test_dataset(tok):
tokenizer = AutoTokenizer.from_pretrained(tok)
tmp_dir = make_test_data_dir()
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = SummarizationDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:
assert batch["attention_mask"].shape == batch["input_ids"].shape
# show that articles were trimmed.
assert batch["input_ids"].shape[1] == max_len_source
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
# show that targets were truncated
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
assert max_len_target > trunc_target # Truncated


@@ -0,0 +1,24 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export BS=32
export GAS=1
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--val_check_interval 0.25 \
--n_val 500 \
--num_train_epochs 2 \
--freeze_encoder --freeze_embeds \
--max_target_length 142 --val_max_target_length=142 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--data_dir $CNN_DIR \
--model_name_or_path sshleifer/student_cnn_12_6 \
--tokenizer_name facebook/bart-large \
--output_dir distilbart-cnn-12-6 \
"$@"


@@ -0,0 +1,20 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export BS=16
export GAS=2
python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.1 --n_val 1000 \
--teacher facebook/bart-large-xsum --data_dir $XSUM_DIR \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--student_decoder_layers 6 --student_encoder_layers 12 \
--freeze_encoder --freeze_embeds \
--model_name_or_path IGNORED \
--alpha_hid=3. --length_penalty=0.5 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
--tokenizer_name facebook/bart-large \
--output_dir distilbart_xsum_12_6 \
"$@"


@@ -3,12 +3,13 @@ import json
import os
import pickle
from pathlib import Path
from typing import Dict, Iterable, List
from typing import Callable, Dict, Iterable, List
import git
import numpy as np
import torch
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
@@ -41,7 +42,7 @@ def encode_file(
examples = []
for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
tokenized = tokenizer.batch_encode_plus(
[text], # DONT ADD SPACES
[text],
max_length=max_length,
pad_to_max_length=pad_to_max_length,
add_prefix_space=True,
@@ -54,11 +55,13 @@ def encode_file(
return examples
def lmap(f, x):
def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
T5_PREFIX = "summarize: " # HACK, fixme
def calculate_bleu_score(output_lns, refs_lns) -> dict:
return {"bleu": corpus_bleu(output_lns, [refs_lns]).score}
def trim_batch(
@@ -95,6 +98,8 @@ class SummarizationDataset(Dataset):
tok_name=tok_name,
)
tgt_path = os.path.join(data_dir, type_path + ".target")
if hasattr(tokenizer, "set_lang"):
tokenizer.set_lang("ro_RO") # HACK: only applies to mbart
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
@@ -189,14 +194,20 @@ def flatten_list(summary_ids: List[List]):
return [x for x in itertools.chain.from_iterable(summary_ids)]
def save_git_info(folder_path: str):
"""
Log commit info.
"""
def save_git_info(folder_path: str) -> None:
"""Save git information to output_dir/git_log.json"""
repo_infos = get_git_info()
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
json.dump(repo_infos, f, indent=4)
def save_json(content, path):
with open(path, "w") as f:
json.dump(content, f, indent=4)
def load_json(path):
with open(path) as f:
return json.load(f)
def get_git_info():


@@ -1,70 +0,0 @@
### Data
CNN/DailyMail data
```bash
cd examples/summarization
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```
This should make a directory called `cnn_dm/` with files like `test.source`.
To use your own data, copy that file's format: each article to be summarized is on its own line.
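As a rough illustration of that layout, the sketch below writes a tiny dataset with one example per line. The helper name `write_split` and the sample texts are hypothetical, and the `train`/`val`/`test` split names with `.source`/`.target` suffixes are assumed from the test fixtures in this repository rather than stated in this section.

```python
import tempfile
from pathlib import Path
from typing import List


def write_split(data_dir: Path, split: str, sources: List[str], targets: List[str]) -> None:
    """Write <split>.source and <split>.target with one example per line (hypothetical helper)."""
    if len(sources) != len(targets):
        raise ValueError("sources and targets must have the same number of examples")
    for suffix, lines in (("source", sources), ("target", targets)):
        # Collapse internal whitespace so an embedded newline cannot split one example in two.
        cleaned = [" ".join(line.split()) for line in lines]
        (data_dir / f"{split}.{suffix}").write_text("\n".join(cleaned) + "\n", encoding="utf-8")


data_dir = Path(tempfile.mkdtemp(prefix="my_data_"))
articles = ["Sam ate lunch today.", "Sam's lunch had\nmany ingredients."]
summaries = ["A story about lunch.", "Avocado, celery, turkey, coffee."]
for split in ("train", "val", "test"):
    write_split(data_dir, split, articles, summaries)
```

Line `i` of a `.source` file is paired with line `i` of the matching `.target` file, so both files in a split need the same number of lines.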
XSUM Data:
```bash
cd examples/summarization
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```
### Evaluation
To create summaries for each article in dataset, run:
```bash
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
```
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
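To make the effect of the batch size concrete, here is a minimal sketch of how a list of input lines is split into batches. It assumes `run_eval.py` batches its input the same way as the `chunks` helper in `evaluate_wmt.py`, which this section does not confirm; the sample articles are made up.

```python
from typing import Iterator, List


def chunks(lines: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield successive batch_size-sized slices of lines; the last one may be shorter."""
    for start in range(0, len(lines), batch_size):
        yield lines[start : start + batch_size]


articles = [f"article {i}" for i in range(10)]
batches = list(chunks(articles, 4))
print([len(batch) for batch in batches])  # [4, 4, 2]
```

Each batch is padded and sent through the model together, so memory use grows with the batch size while the number of forward passes shrinks.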
### Training
Run/modify `finetune.sh`
The following command should work on a 16GB GPU:
```bash
export me=`git config user.name`
./finetune.sh \
--data_dir $XSUM_DIR \
--train_batch_size=1 \
--eval_batch_size=1 \
--output_dir="$me"_xsum_results \
--num_train_epochs 1
```
Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours, requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see below)
- `fp16_opt_level=O1` (the default) works best.
- If you are finetuning on your own dataset, start from `bart-large-cnn` if you want long summaries and `bart-large-xsum` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods).
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
### XSUM Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
Here is an example command:
```bash
export me=`git config user.name`
./finetune.sh \
--data_dir $XSUM_DIR \
--output_dir "$me"_xsum_frozen_embs \
--logger wandb_shared \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6
```
Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_summarization/table?workspace=user-)


@@ -1,267 +0,0 @@
import argparse
import logging
import os
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
import torch
from torch.utils.data import DataLoader
from transformers import BartTokenizer
from .distillation import distill_main, evaluate_checkpoint
from .finetune import main
from .run_eval import generate_summaries, run_generate
from .utils import SummarizationDataset, lmap, pickle_load
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
FP16_EVER = False
CHEAP_ARGS = {
"logger": "default",
"num_workers": 2,
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
"tgt_suffix": "",
"resume_from_checkpoint": None,
"sortish_sampler": True,
"student_decoder_layers": 1,
"val_check_interval": 1.0,
"output_dir": "",
"fp16": False,
"no_teacher": False,
"fp16_opt_level": "O1",
"gpus": 1 if torch.cuda.is_available() else 0,
"n_tpu_cores": 0,
"max_grad_norm": 1.0,
"do_train": True,
"do_predict": True,
"gradient_accumulation_steps": 1,
"server_ip": "",
"server_port": "",
"seed": 42,
"model_type": "bart",
"model_name_or_path": "sshleifer/bart-tiny-random",
"config_name": "",
"tokenizer_name": "facebook/bart-large",
"cache_dir": "",
"do_lower_case": False,
"learning_rate": 3e-05,
"weight_decay": 0.0,
"adam_epsilon": 1e-08,
"warmup_steps": 0,
"num_train_epochs": 1,
"train_batch_size": 2,
"eval_batch_size": 2,
"max_source_length": 12,
"max_target_length": 12,
"val_max_target_length": 12,
"test_max_target_length": 12,
"fast_dev_run": False,
"no_cache": False,
"n_train": -1,
"n_val": -1,
"n_test": -1,
"student_encoder_layers": 1,
"alpha_loss_encoder": 0.0,
"freeze_encoder": False,
"auto_scale_batch_size": False,
}
def _dump_articles(path: Path, articles: list):
with path.open("w") as f:
f.write("\n".join(articles))
MSG = "T5 is broken at the moment"
T5_TINY = "patrickvonplaten/t5-tiny-random"
def make_test_data_dir():
tmp_dir = Path(tempfile.gettempdir())
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
for split in ["train", "val", "test"]:
_dump_articles((tmp_dir / f"{split}.source"), articles)
_dump_articles((tmp_dir / f"{split}.target"), summaries)
return tmp_dir
class TestSummarizationDistiller(unittest.TestCase):
@classmethod
def setUpClass(cls):
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls
@unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test")
def test_bdc_multigpu(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
no_teacher=True,
freeze_encoder=True,
gpus=2,
sortish_sampler=False,
fp16_opt_level="O1",
fp16=FP16_EVER,
)
self._bart_distiller_cli(updates)
def test_bdc_t5_train(self):
updates = dict(
fp16=FP16_EVER,
gpus=1 if torch.cuda.is_available() else 0,
model_type="t5",
model_name_or_path=T5_TINY,
do_train=True,
do_predict=True,
tokenizer_name=T5_TINY,
no_teacher=True,
alpha_hid=2.0,
)
self._bart_distiller_cli(updates)
def test_bdc_no_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True,)
self._bart_distiller_cli(updates)
def test_bdc_yes_teacher(self):
updates = dict(student_encoder_layers=2, student_decoder_layers=1,)
self._bart_distiller_cli(updates)
def test_bdc_checkpointing(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
num_train_epochs=4,
val_check_interval=0.25,
alpha_hid=2.0,
)
model = self._bart_distiller_cli(updates, check_contents=False)
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
self.assertEqual(len(transformer_ckpts), len(ckpts))
new_transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
self.assertEqual(len(new_transformer_ckpts), 1)
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
out_path = tempfile.mktemp()
generate_summaries(examples, out_path, new_transformer_ckpts[0].parent)
self.assertTrue(Path(out_path).exists())
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
def _bart_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
alpha_mlm=0.2,
alpha_ce=0.8,
do_predict=True,
gpus=1 if torch.cuda.is_available() else 0,
model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5,
alpha_encoder_loss=0.4,
)
default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
model = distill_main(argparse.Namespace(**args_d))
if not check_contents:
return model
contents = os.listdir(output_dir)
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
contents = {os.path.basename(p) for p in contents}
self.assertIn(ckpt_name, contents)
self.assertIn("metrics.pkl", contents)
self.assertIn("test_generations.txt", contents)
self.assertIn("val_generations_00001.txt", contents)
self.assertIn("val_results_00001.txt", contents)
self.assertIn("test_results.txt", contents)
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here
return model
class TestBartExamples(unittest.TestCase):
@classmethod
def setUpClass(cls):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
return cls
def test_bart_cnn_cli(self):
tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
_dump_articles(tmp, articles)
testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
with patch.object(sys, "argv", testargs):
run_generate()
self.assertTrue(Path(output_file_name).exists())
os.remove(Path(output_file_name))
def test_t5_run_sum_cli(self):
args_d: dict = CHEAP_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_name_or_path=T5_TINY,
tokenizer_name=None, # T5_TINY,
train_batch_size=2,
eval_batch_size=2,
gpus=0,
output_dir=output_dir,
do_predict=True,
)
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
main(args)
def test_bart_summarization_dataset(self):
tmp_dir = Path(tempfile.gettempdir())
articles = [" Sam ate lunch today", "Sams lunch ingredients"]
summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
_dump_articles((tmp_dir / "train.source"), articles)
_dump_articles((tmp_dir / "train.target"), summaries)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
max_len_source = max(len(tokenizer.encode(a)) for a in articles)
max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
trunc_target = 4
train_dataset = SummarizationDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:
self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape)
# show that articles were trimmed.
self.assertEqual(batch["input_ids"].shape[1], max_len_source)
self.assertGreater(20, batch["input_ids"].shape[1]) # trimmed significantly
# show that targets were truncated
self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target) # Truncated
self.assertGreater(max_len_target, trunc_target) # Truncated
def list_to_text_file(lst, path):
dest = Path(path)
dest.open("w+").writelines(lst)


@@ -1,51 +0,0 @@
***This script evaluates the multitask pre-trained checkpoint for ``t5-base`` (see paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the English to German WMT dataset. Please note that the results in the paper were attained using a model fine-tuned on translation, so that results will be worse here by approx. 1.5 BLEU points***
### Intro
This example shows how T5 (here the official [paper](https://arxiv.org/abs/1910.10683)) can be
evaluated on the WMT English-German dataset.
### Get the WMT Data
To be able to reproduce the authors' results on WMT English to German, you first need to download
the WMT14 en-de news datasets.
Go to Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2014.en" and "newstest2014.de" under WMT'14 English-German data, or download the dataset directly via:
```bash
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en > newstest2014.en
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de > newstest2014.de
```
You should have 2737 sentences in each file. You can verify this by running:
```bash
wc -l newstest2014.en # should give 2737
```
### Usage
Let's check the longest and shortest sentence in our file to find reasonable decoding hyperparameters:
Get the longest and shortest sentence:
```bash
awk '{print NF}' newstest2014.en | sort -n | head -1 # shortest sentence has 2 words
awk '{print NF}' newstest2014.en | sort -n | tail -1 # longest sentence has 91 words
```
We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0.
We decode with beam search `num_beams=4` as proposed in the paper. Also as is common in beam search we set `early_stopping=True` and `length_penalty=2.0`.
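The length check above can also be sketched in Python. The function name `word_count_range` and the sample sentences are hypothetical; the code counts whitespace-separated words per line, as `awk '{print NF}'` does, and applies the "about 3 times the longest sentence" rule of thumb from this section.

```python
from typing import List, Tuple


def word_count_range(lines: List[str]) -> Tuple[int, int]:
    """Return (shortest, longest) whitespace-separated word counts over all lines."""
    counts = [len(line.split()) for line in lines]
    return min(counts), max(counts)


sample = [
    "Hello world",
    "The quick brown fox jumps over the lazy dog",
]
shortest, longest = word_count_range(sample)
max_length = 3 * longest  # rule of thumb: roughly 3x the longest input
print(shortest, longest, max_length)  # 2 9 27
```

Applied to the 91-word longest sentence reported above, the same rule gives a `max_length` of roughly 273. Note that the model's `max_length` counts tokens rather than words, which is one reason to leave generous headroom.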
To create a translation for each sentence in the dataset and get a final BLEU score, run:
```bash
python evaluate_wmt.py t5-base <path_to_newstest2014.en> newstest2014_de_translations.txt <path_to_newstest2014.de> newstest2014_en_de_bleu.txt
```
The default batch size, 16, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
### Where is the code?
The core model is in `src/transformers/modeling_t5.py`. This directory only contains examples.
### BLEU Scores
The BLEU score is calculated using [sacrebleu](https://github.com/mjpost/sacreBLEU) by mjpost.
To get the BLEU score we used


@@ -1,103 +0,0 @@
import argparse
from pathlib import Path
import torch
from sacrebleu import corpus_bleu
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
def generate_translations(lns, output_file_path, model_size, batch_size, device):
model = T5ForConditionalGeneration.from_pretrained(model_size)
model.to(device)
tokenizer = T5Tokenizer.from_pretrained(model_size)
# update config with translation-specific params
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
model.config.update(task_specific_params.get("translation_en_to_de", {}))
with Path(output_file_path).open("w") as output_file:
for batch in tqdm(list(chunks(lns, batch_size))):
batch = [model.config.prefix + text for text in batch]
dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)
input_ids = dct["input_ids"].to(device)
attention_mask = dct["attention_mask"].to(device)
translations = model.generate(input_ids=input_ids, attention_mask=attention_mask)
dec = [
tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations
]
for hypothesis in dec:
output_file.write(hypothesis + "\n")
def calculate_bleu_score(output_lns, refs_lns, score_path):
bleu = corpus_bleu(output_lns, [refs_lns])
result = "BLEU score: {}".format(bleu.score)
with Path(score_path).open("w") as score_file:
score_file.write(result)
def run_generate():
parser = argparse.ArgumentParser()
parser.add_argument(
"model_size",
type=str,
help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
default="t5-base",
)
parser.add_argument(
"input_path", type=str, help="like wmt/newstest2014.en",
)
parser.add_argument(
"output_path", type=str, help="where to save translation",
)
parser.add_argument(
"reference_path", type=str, help="like wmt/newstest2014.de",
)
parser.add_argument(
"score_path", type=str, help="where to save the bleu score",
)
parser.add_argument(
"--batch_size", type=int, default=16, required=False, help="batch size: how many sentences to translate at a time",
)
parser.add_argument(
"--no_cuda", action="store_true", help="Whether to force the execution on CPU.",
)
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
dash_pattern = (" ##AT##-##AT## ", "-")
# Read input lines into python
with open(args.input_path, "r") as input_file:
input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in input_file.readlines()]
generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device)
# Read generated lines into python
with open(args.output_path, "r") as output_file:
output_lns = [x.strip() for x in output_file.readlines()]
# Read reference lines into python
with open(args.reference_path, "r") as reference_file:
refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in reference_file.readlines()]
calculate_bleu_score(output_lns, refs_lns, args.score_path)
if __name__ == "__main__":
run_generate()


@@ -1,50 +0,0 @@
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch
from .evaluate_wmt import run_generate
text = ["When Liana Barrientos was 23 years old, she got married in Westchester County."]
translation = ["Als Liana Barrientos 23 Jahre alt war, heiratete sie in Westchester County."]
output_file_name = "output_t5_trans.txt"
score_file_name = "score_t5_trans.txt"
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
class TestT5Examples(unittest.TestCase):
def test_t5_cli(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_source = Path(tempfile.gettempdir()) / "utest_generations_t5_trans.hypo"
with tmp_source.open("w") as f:
f.write("\n".join(text))
tmp_target = Path(tempfile.gettempdir()) / "utest_generations_t5_trans.target"
with tmp_target.open("w") as f:
f.write("\n".join(translation))
output_file_name = Path(tempfile.gettempdir()) / "utest_output_trans.hypo"
score_file_name = Path(tempfile.gettempdir()) / "utest_score.hypo"
testargs = [
"evaluate_wmt.py",
"patrickvonplaten/t5-tiny-random",
str(tmp_source),
str(output_file_name),
str(tmp_target),
str(score_file_name),
]
with patch.object(sys, "argv", testargs):
run_generate()
self.assertTrue(Path(output_file_name).exists())
self.assertTrue(Path(score_file_name).exists())


@@ -20,6 +20,7 @@ known_third_party =
pandas
PIL
psutil
pytest
pytorch_lightning
rouge_score
sacrebleu


@@ -55,7 +55,7 @@ class BartTokenizerFast(RobertaTokenizerFast):
}
_all_mbart_models = ["facebook/mbart-large-en-ro"]
_all_mbart_models = ["facebook/mbart-large-en-ro", "sshleifer/mbart-large-cc25"]
SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"
@@ -105,6 +105,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
"vi_VN": 250024,
"zh_CN": 250025,
}
id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
cur_lang_code = lang_code_to_id["en_XX"]
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
@@ -115,6 +116,16 @@ class MBartTokenizer(XLMRobertaTokenizer):
# We don't expect to process pairs, but leave the pair logic for API consistency
return token_ids_0 + token_ids_1 + special_tokens
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.id_to_lang_code:
return self.id_to_lang_code[index]
return self.sp_model.IdToPiece(index - self.fairseq_offset)
def set_lang(self, lang: str) -> None:
"""Set the current language code in order to call batch_encode_plus properly."""
self.cur_lang_code = self.lang_code_to_id[lang]
def prepare_translation_batch(
self,
src_texts: List[str],