Add mbart-large-cc25, support translation finetuning (#5129)
improve unittests for finetuning, especially w.r.t testing frozen parameters fix freeze_embeds for T5 add streamlit setup.cfg
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import faiss
|
||||
import nlp
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
import streamlit as st
|
||||
import transformers
|
||||
from eli5_utils import (
|
||||
embed_questions_for_retrieval,
|
||||
|
||||
@@ -41,6 +41,28 @@ If you are using your own data, it must be formatted as one directory with 6 fil
|
||||
The `.source` files are the input, the `.target` files are the desired output.
|
||||
|
||||
|
||||
### Tips and Tricks
|
||||
|
||||
General Tips:
|
||||
- since you need to run from `examples/seq2seq`, and likely need to modify code, the easiest workflow is fork transformers, clone your fork, and run `pip install -e .` before you get started.
|
||||
- try `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr per epoch with bs=8, see the "xsum_shared_task" command below)
|
||||
- `fp16_opt_level=O1` (the default works best).
|
||||
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
|
||||
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
|
||||
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
|
||||
- This warning can be safely ignored:
|
||||
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
|
||||
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
|
||||
- Read scripts before you run them!
|
||||
|
||||
Summarization Tips:
|
||||
- (summ) 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
|
||||
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
||||
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
|
||||
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
|
||||
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
|
||||
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
|
||||
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
|
||||
|
||||
### Summarization Finetuning
|
||||
Run/modify `finetune.sh`
|
||||
@@ -58,25 +80,20 @@ The following command should work on a 16GB GPU:
|
||||
|
||||
*Note*: The following tips mostly apply to summarization finetuning.
|
||||
|
||||
Tips:
|
||||
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
|
||||
- since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork, then clone transformers and run `pip install -e .` before you get started.
|
||||
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
|
||||
- `fp16_opt_level=O1` (the default works best).
|
||||
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
|
||||
(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
|
||||
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
|
||||
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`.
|
||||
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
|
||||
- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
|
||||
- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
|
||||
- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 ` is a reasonable setting for XSUM.
|
||||
- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
|
||||
- This warning can be safely ignored:
|
||||
> "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
|
||||
- Both finetuning and eval are 30% faster with `--fp16`. For that you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
|
||||
### Translation Finetuning
|
||||
|
||||
#### Finetuning Outputs
|
||||
First, follow the wmt_en_ro download instructions.
|
||||
Then you can finetune mbart_cc25 on english-romanian with the following command.
|
||||
**Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it.
|
||||
```bash
|
||||
export ENRO_DIR=${PWD}/wmt_en_ro # may need to be fixed depending on where you downloaded
|
||||
export BS=4
|
||||
export GAS=8
|
||||
./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/
|
||||
```
|
||||
|
||||
|
||||
### Finetuning Outputs
|
||||
As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine).
|
||||
Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
|
||||
|
||||
|
||||
@@ -14,11 +14,12 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import get_linear_schedule_with_warmup
|
||||
from transformers import MBartTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
try:
|
||||
from .utils import (
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
@@ -47,6 +48,7 @@ except ImportError:
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
assert_all_frozen,
|
||||
)
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
|
||||
@@ -92,9 +94,12 @@ class SummarizationModule(BaseTransformer):
|
||||
if self.hparams.freeze_embeds:
|
||||
self.freeze_embeds()
|
||||
if self.hparams.freeze_encoder:
|
||||
freeze_params(self.model.model.encoder) # TODO: this will break for t5
|
||||
freeze_params(self.model.get_encoder())
|
||||
assert_all_frozen(self.model.get_encoder())
|
||||
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.decoder_start_token_id = None
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
@@ -160,7 +165,12 @@ class SummarizationModule(BaseTransformer):
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||
t0 = time.time()
|
||||
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
|
||||
generated_ids = self.model.generate(
|
||||
input_ids=source_ids,
|
||||
attention_mask=source_mask,
|
||||
use_cache=True,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
)
|
||||
gen_time = (time.time() - t0) / source_ids.shape[0]
|
||||
preds = self.ids_to_clean_text(generated_ids)
|
||||
target = self.ids_to_clean_text(y)
|
||||
@@ -276,6 +286,9 @@ class SummarizationModule(BaseTransformer):
|
||||
parser.add_argument(
|
||||
"--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
|
||||
)
|
||||
parser.add_argument("--src_lang", type=str, default="", required=False)
|
||||
parser.add_argument("--tgt_lang", type=str, default="", required=False)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -285,6 +298,13 @@ class TranslationModule(SummarizationModule):
|
||||
metric_names = ["bleu"]
|
||||
val_metric = "bleu"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.dataset_kwargs["src_lang"] = hparams.src_lang
|
||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu_score(preds, target)
|
||||
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
export OUTPUT_DIR_NAME=t5
|
||||
export CURRENT_DIR=${PWD}
|
||||
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
|
||||
|
||||
# Make output directory if it doesn't exist
|
||||
mkdir -p $OUTPUT_DIR
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
python finetune.py \
|
||||
--data_dir=./cnn-dailymail/cnn_dm \
|
||||
--model_name_or_path=t5-large \
|
||||
--data_dir=$CNN_DIR \
|
||||
--learning_rate=3e-5 \
|
||||
--train_batch_size=4 \
|
||||
--eval_batch_size=4 \
|
||||
--train_batch_size=$BS \
|
||||
--eval_batch_size=$BS \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_train $@
|
||||
--max_source_length=512 \
|
||||
--val_check_interval=0.1 --n_val=200 \
|
||||
--do_train --do_predict \
|
||||
$@
|
||||
|
||||
@@ -223,10 +223,30 @@ def test_finetune(model):
|
||||
output_dir=output_dir,
|
||||
do_predict=True,
|
||||
task=task,
|
||||
src_lang="en_XX",
|
||||
tgt_lang="ro_RO",
|
||||
freeze_encoder=True,
|
||||
freeze_embeds=True,
|
||||
)
|
||||
assert "n_train" in args_d
|
||||
args = argparse.Namespace(**args_d)
|
||||
main(args)
|
||||
module = main(args)
|
||||
|
||||
input_embeds = module.model.get_input_embeddings()
|
||||
assert not input_embeds.weight.requires_grad
|
||||
if model == T5_TINY:
|
||||
lm_head = module.model.lm_head
|
||||
assert not lm_head.weight.requires_grad
|
||||
assert (lm_head.weight == input_embeds.weight).all().item()
|
||||
|
||||
else:
|
||||
bart = module.model.model
|
||||
embed_pos = bart.decoder.embed_positions
|
||||
assert not embed_pos.weight.requires_grad
|
||||
assert not bart.shared.weight.requires_grad
|
||||
# check that embeds are the same
|
||||
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
|
||||
assert bart.decoder.embed_tokens == bart.shared
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -239,7 +259,12 @@ def test_dataset(tok):
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc_target = 4
|
||||
train_dataset = SummarizationDataset(
|
||||
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=20,
|
||||
max_target_length=trunc_target,
|
||||
tgt_lang="ro_RO",
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
|
||||
21
examples/seq2seq/train_mbart_cc25_enro.sh
Executable file
21
examples/seq2seq/train_mbart_cc25_enro.sh
Executable file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
python finetune.py \
|
||||
--learning_rate=3e-5 \
|
||||
--fp16 \
|
||||
--gpus 1 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--val_check_interval 0.1 \
|
||||
--n_val 500 \
|
||||
--adam_eps 1e-06 \
|
||||
--num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \
|
||||
--freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \
|
||||
--max_source_length=300 --max_target_length 300 --val_max_target_length=300 --test_max_target_length 300 \
|
||||
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
|
||||
--model_name_or_path facebook/mbart-large-cc25 \
|
||||
--task translation \
|
||||
--warmup_steps 500 \
|
||||
--logger wandb --sortish_sampler \
|
||||
$@
|
||||
@@ -14,6 +14,8 @@ from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartTokenizer
|
||||
|
||||
|
||||
def encode_file(
|
||||
tokenizer,
|
||||
@@ -25,6 +27,7 @@ def encode_file(
|
||||
prefix="",
|
||||
tok_name="",
|
||||
):
|
||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
|
||||
cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
|
||||
if not overwrite_cache and cache_path.exists():
|
||||
try:
|
||||
@@ -46,8 +49,8 @@ def encode_file(
|
||||
max_length=max_length,
|
||||
padding="max_length" if pad_to_max_length else None,
|
||||
truncation=True,
|
||||
add_prefix_space=True,
|
||||
return_tensors=return_tensors,
|
||||
**extra_kw,
|
||||
)
|
||||
assert tokenized.input_ids.shape[1] == max_length
|
||||
examples.append(tokenized)
|
||||
@@ -87,9 +90,14 @@ class SummarizationDataset(Dataset):
|
||||
n_obs=None,
|
||||
overwrite_cache=False,
|
||||
prefix="",
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
):
|
||||
super().__init__()
|
||||
# FIXME: the rstrip logic strips all the chars, it seems.
|
||||
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
|
||||
if hasattr(tokenizer, "set_lang") and src_lang is not None:
|
||||
tokenizer.set_lang(src_lang) # HACK: only applies to mbart
|
||||
self.source = encode_file(
|
||||
tokenizer,
|
||||
os.path.join(data_dir, type_path + ".source"),
|
||||
@@ -100,7 +108,8 @@ class SummarizationDataset(Dataset):
|
||||
)
|
||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||
if hasattr(tokenizer, "set_lang"):
|
||||
tokenizer.set_lang("ro_RO") # HACK: only applies to mbart
|
||||
assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
|
||||
tokenizer.set_lang(tgt_lang) # HACK: only applies to mbart
|
||||
self.target = encode_file(
|
||||
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
||||
)
|
||||
@@ -224,8 +233,8 @@ def get_git_info():
|
||||
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
|
||||
|
||||
|
||||
def calculate_rouge(output_lns: List[str], reference_lns: List[str]) -> Dict:
|
||||
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=True)
|
||||
def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
|
||||
scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
|
||||
aggregator = scoring.BootstrapAggregator()
|
||||
|
||||
for reference_ln, output_ln in zip(reference_lns, output_lns):
|
||||
|
||||
Reference in New Issue
Block a user