Add mbart-large-cc25, support translation finetuning (#5129)

improve unittests for finetuning, especially w.r.t testing frozen parameters
fix freeze_embeds for T5
add streamlit setup.cfg
This commit is contained in:
Sam Shleifer
2020-07-07 13:23:01 -04:00
committed by GitHub
parent 141492448b
commit 353b8f1e7a
14 changed files with 521 additions and 204 deletions

View File

@@ -1,10 +1,10 @@
import faiss
import nlp
import numpy as np
import streamlit as st
import torch
from elasticsearch import Elasticsearch
import streamlit as st
import transformers
from eli5_utils import (
embed_questions_for_retrieval,

View File

@@ -41,6 +41,28 @@ If you are using your own data, it must be formatted as one directory with 6 fil
The `.source` files are the input, the `.target` files are the desired output.
### Tips and Tricks
General Tips:
- since you need to run from `examples/seq2seq`, and likely need to modify code, the easiest workflow is fork transformers, clone your fork, and run `pip install -e .` before you get started.
- try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr per epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default works best).
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
- Read scripts before you run them!
Summarization Tips:
- (summ) 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
### Summarization Finetuning
Run/modify `finetune.sh`
@@ -58,25 +80,20 @@ The following command should work on a 16GB GPU:
*Note*: The following tips mostly apply to summarization finetuning.
Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork, then clone transformers and run `pip install -e .` before you get started.
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default works best).
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
- This warning can be safely ignored:
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
### Translation Finetuning
#### Finetuning Outputs
First, follow the wmt_en_ro download instructions.
Then you can finetune mbart_cc25 on english-romanian with the following command.
**Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it.
```bash
export ENRO_DIR=${PWD}/wmt_en_ro # may need to be fixed depending on where you downloaded
export BS=4
export GAS=8
./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/
```
### Finetuning Outputs
As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine).
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:

View File

@@ -14,11 +14,12 @@ import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import get_linear_schedule_with_warmup
from transformers import MBartTokenizer, get_linear_schedule_with_warmup
try:
from .utils import (
assert_all_frozen,
use_task_specific_params,
SummarizationDataset,
lmap,
@@ -47,6 +48,7 @@ except ImportError:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
assert_all_frozen,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
@@ -92,9 +94,12 @@ class SummarizationModule(BaseTransformer):
if self.hparams.freeze_embeds:
self.freeze_embeds()
if self.hparams.freeze_encoder:
freeze_params(self.model.model.encoder) # TODO: this will break for t5
freeze_params(self.model.get_encoder())
assert_all_frozen(self.model.get_encoder())
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.decoder_start_token_id = None
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -160,7 +165,12 @@ class SummarizationModule(BaseTransformer):
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
t0 = time.time()
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
generated_ids = self.model.generate(
input_ids=source_ids,
attention_mask=source_mask,
use_cache=True,
decoder_start_token_id=self.decoder_start_token_id,
)
gen_time = (time.time() - t0) / source_ids.shape[0]
preds = self.ids_to_clean_text(generated_ids)
target = self.ids_to_clean_text(y)
@@ -276,6 +286,9 @@ class SummarizationModule(BaseTransformer):
parser.add_argument(
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
return parser
@@ -285,6 +298,13 @@ class TranslationModule(SummarizationModule):
metric_names = ["bleu"]
val_metric = "bleu"
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.dataset_kwargs["src_lang"] = hparams.src_lang
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu_score(preds, target)

View File

@@ -1,18 +1,13 @@
export OUTPUT_DIR_NAME=t5
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_name_or_path=t5-large \
--data_dir=$CNN_DIR \
--learning_rate=3e-5 \
--train_batch_size=4 \
--eval_batch_size=4 \
--train_batch_size=$BS \
--eval_batch_size=$BS \
--output_dir=$OUTPUT_DIR \
--do_train $@
--max_source_length=512 \
--val_check_interval=0.1 --n_val=200 \
--do_train --do_predict \
$@

View File

@@ -223,10 +223,30 @@ def test_finetune(model):
output_dir=output_dir,
do_predict=True,
task=task,
src_lang="en_XX",
tgt_lang="ro_RO",
freeze_encoder=True,
freeze_embeds=True,
)
assert "n_train" in args_d
args = argparse.Namespace(**args_d)
main(args)
module = main(args)
input_embeds = module.model.get_input_embeddings()
assert not input_embeds.weight.requires_grad
if model == T5_TINY:
lm_head = module.model.lm_head
assert not lm_head.weight.requires_grad
assert (lm_head.weight == input_embeds.weight).all().item()
else:
bart = module.model.model
embed_pos = bart.decoder.embed_positions
assert not embed_pos.weight.requires_grad
assert not bart.shared.weight.requires_grad
# check that embeds are the same
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
assert bart.decoder.embed_tokens == bart.shared
@pytest.mark.parametrize(
@@ -239,7 +259,12 @@ def test_dataset(tok):
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
trunc_target = 4
train_dataset = SummarizationDataset(
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
tokenizer,
data_dir=tmp_dir,
type_path="train",
max_source_length=20,
max_target_length=trunc_target,
tgt_lang="ro_RO",
)
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
for batch in dataloader:

View File

@@ -0,0 +1,21 @@
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
python finetune.py \
--learning_rate=3e-5 \
--fp16 \
--gpus 1 \
--do_train \
--do_predict \
--val_check_interval 0.1 \
--n_val 500 \
--adam_eps 1e-06 \
--num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \
--freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \
--max_source_length=300 --max_target_length 300 --val_max_target_length=300 --test_max_target_length 300 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
--model_name_or_path facebook/mbart-large-cc25 \
--task translation \
--warmup_steps 500 \
--logger wandb --sortish_sampler \
$@

View File

@@ -14,6 +14,8 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from transformers import BartTokenizer
def encode_file(
tokenizer,
@@ -25,6 +27,7 @@ def encode_file(
prefix="",
tok_name="",
):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
if not overwrite_cache and cache_path.exists():
try:
@@ -46,8 +49,8 @@ def encode_file(
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
add_prefix_space=True,
return_tensors=return_tensors,
**extra_kw,
)
assert tokenized.input_ids.shape[1] == max_length
examples.append(tokenized)
@@ -87,9 +90,14 @@ class SummarizationDataset(Dataset):
n_obs=None,
overwrite_cache=False,
prefix="",
src_lang=None,
tgt_lang=None,
):
super().__init__()
# FIXME: the rstrip logic strips all the chars, it seems.
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
if hasattr(tokenizer, "set_lang") and src_lang is not None:
tokenizer.set_lang(src_lang) # HACK: only applies to mbart
self.source = encode_file(
tokenizer,
os.path.join(data_dir, type_path + ".source"),
@@ -100,7 +108,8 @@ class SummarizationDataset(Dataset):
)
tgt_path = os.path.join(data_dir, type_path + ".target")
if hasattr(tokenizer, "set_lang"):
tokenizer.set_lang("ro_RO") # HACK: only applies to mbart
assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
@@ -224,8 +233,8 @@ def get_git_info():
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
aggregator = scoring.BootstrapAggregator()
for reference_ln, output_ln in zip(reference_lns, output_lns):