From 5b193b39b09b30e654d8328f418f632ee618098c Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 21 Jul 2020 16:51:39 -0400 Subject: [PATCH] [examples/seq2seq]: add --label_smoothing option (#5919) --- examples/seq2seq/README.md | 39 +++++++++++++--- examples/seq2seq/callbacks.py | 6 ++- examples/seq2seq/finetune.py | 45 +++++++++++++++---- examples/seq2seq/test_seq2seq_examples.py | 31 +++++++++++-- examples/seq2seq/train_mbart_cc25_enro.sh | 12 ++--- .../seq2seq/train_mbart_cc25_enro_multigpu.sh | 18 -------- examples/seq2seq/utils.py | 27 ++++++++++- 7 files changed, 132 insertions(+), 46 deletions(-) delete mode 100755 examples/seq2seq/train_mbart_cc25_enro_multigpu.sh diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index d0384bf274..f72d9dc7b7 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -27,8 +27,18 @@ this should make a directory called `cnn_dm/` with files like `test.source`. ``` WMT16 English-Romanian Translation Data: + +This dataset comes in two formats. The "packed" version merges short training examples into examples of <200 tokens to increase GPU utilization (and also improves validation performance). + ```bash cd examples/seq2seq +https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro_packed_train_200.tgz +tar -xzvf wmt_en_ro_packed_200.tgz +export ENRO_DIR=wmt_en_ro_packed_train_200 +``` + +The original data can also be downloaded with this command: +```bash wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz tar -xzvf wmt_en_ro.tar.gz export ENRO_DIR=${PWD}/wmt_en_ro @@ -84,16 +94,31 @@ The following command should work on a 16GB GPU: First, follow the wmt_en_ro download instructions. Then you can finetune mbart_cc25 on english-romanian with the following command. -**Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it. +**Recommendation:** Read and potentially modify the fairly opinionated defaults in `train_mbart_cc25_enro.sh` script before running it. + +Best performing command: ```bash -export ENRO_DIR=${PWD}/wmt_en_ro # may need to be fixed depending on where you downloaded -export MAX_LEN=128 +# optionally +export ENRO_DIR='wmt_en_ro_packed_train_200' # Download instructions above +# export WANDB_PROJECT="MT" # optional +export MAX_LEN=200 export BS=4 -export GAS=8 -./train_mbart_cc25_enro.sh --output_dir cc25_v1_frozen/ +export GAS=8 # gradient accumulation steps +./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --label_smoothing 0.1 --fp16_opt_level=O1 --logger_name wandb --sortish_sampler ``` +This should take < 2h/epoch on a 16GB v100 and achieve val_avg_ BLEU score above 25. (you can see in wandb or metrics.json). +To get results in line with fairseq, you need to do some postprocessing. - +MultiGPU command +(using 8 GPUS as an example) +```bash +export ENRO_DIR='wmt_en_ro_packed_train_200' # Download instructions above + # export WANDB_PROJECT="MT" # optional +export MAX_LEN=200 +export BS=4 +export GAS=1 # gradient accumulation steps +./train_mbart_cc25_enro.sh --output_dir enro_finetune_baseline --gpus 8 --logger_name wandb +``` ### Finetuning Outputs As you train, `output_dir` will be filled with files, that look kind of like this (comments are mine). Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour: @@ -108,7 +133,7 @@ output_dir │   ├── tokenizer_config.json │   └── vocab.json ├── git_log.json # repo, branch, and commit hash -├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score. +├── val_avg_rouge2=0.1984-step_count=11.ckpt # this is a pytorch lightning checkpoint associated with the best val score. (it will be called BLEU for MT) ├── metrics.json # new validation metrics will continually be appended to this ├── student # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned. │   ├── config.json diff --git a/examples/seq2seq/callbacks.py b/examples/seq2seq/callbacks.py index 523d18fcd7..1de3aa5d46 100644 --- a/examples/seq2seq/callbacks.py +++ b/examples/seq2seq/callbacks.py @@ -5,7 +5,7 @@ from pathlib import Path import numpy as np import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.utilities import rank_zero_only @@ -90,3 +90,7 @@ def get_checkpoint_callback(output_dir, metric): period=0, # maybe save a checkpoint every time val is run, not just end of epoch. ) return checkpoint_callback + + +def get_early_stopping_callback(metric, patience): + return EarlyStopping(monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 1afeafd042..88b414860f 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -33,9 +33,10 @@ try: calculate_bleu_score, Seq2SeqDataset, MBartDataset, + label_smoothed_nll_loss, ) - from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback + from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback except ImportError: from utils import ( Seq2SeqDataset, @@ -52,8 +53,9 @@ except ImportError: get_git_info, ROUGE_KEYS, calculate_bleu_score, + label_smoothed_nll_loss, ) - from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback + from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback logger = logging.getLogger(__name__) @@ -128,12 +130,22 @@ class SummarizationModule(BaseTransformer): def _step(self, batch: dict) -> Tuple: pad_token_id = self.tokenizer.pad_token_id - source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] - y_ids = y[:, :-1].contiguous() - lm_labels = y[:, 1:].clone() - lm_labels[y[:, 1:] == pad_token_id] = -100 - outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,) - loss = outputs[0] + source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] + decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line? + lm_labels = target_ids[:, 1:].clone() # why clone? + outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False) + + if self.hparams.label_smoothing == 0: + # Same behavior as modeling_bart.py + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) + lm_logits = outputs[0] + assert lm_logits.shape[-1] == self.model.config.vocab_size + loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1)) + else: + lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id + ) return (loss,) def training_step(self, batch, batch_idx) -> Dict: @@ -290,8 +302,16 @@ class SummarizationModule(BaseTransformer): parser.add_argument( "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all." ) + parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) parser.add_argument("--src_lang", type=str, default="", required=False) parser.add_argument("--tgt_lang", type=str, default="", required=False) + parser.add_argument( + "--early_stopping_patience", + type=int, + default=-1, + required=False, + help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.", + ) return parser @@ -335,17 +355,24 @@ def main(args, model=None) -> SummarizationModule: elif args.logger_name == "wandb": from pytorch_lightning.loggers import WandbLogger - logger = WandbLogger(name=model.output_dir.name, project=dataset) + project = os.environ.get("WANDB_PROJECT", dataset) + logger = WandbLogger(name=model.output_dir.name, project=project) elif args.logger_name == "wandb_shared": from pytorch_lightning.loggers import WandbLogger logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}") + + if args.early_stopping_patience >= 0: + es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience) + else: + es_callback = False trainer: pl.Trainer = generic_train( model, args, logging_callback=Seq2SeqLoggingCallback(), checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric), + early_stopping_callback=es_callback, logger=logger, # TODO: early stopping callback seems messed up ) diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 6e9117228a..e25fb0b0e7 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -12,14 +12,14 @@ import torch from pytest import param from torch.utils.data import DataLoader -from transformers import AutoTokenizer, MBartTokenizer +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, MBartTokenizer from transformers.testing_utils import require_multigpu from .distillation import distill_main, evaluate_checkpoint from .finetune import main from .pack_dataset import pack_data_dir from .run_eval import generate_summaries_or_translations, run_generate -from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json +from .utils import MBartDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json logging.basicConfig(level=logging.DEBUG) @@ -27,7 +27,8 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger() CUDA_AVAILABLE = torch.cuda.is_available() CHEAP_ARGS = { - "label_smoothing_eps": 0.2, + "label_smoothing": 0.2, + "early_stopping_patience": 2, "logger_name": "default", "length_penalty": 0.5, "cache_dir": "", @@ -142,6 +143,26 @@ class TestSummarizationDistiller(unittest.TestCase): evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp())) + def test_loss_fn(self): + model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY) + input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"] + target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device) + decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line? + lm_labels = target_ids[:, 1:].clone() # why clone? + model_computed_loss = model( + input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, labels=lm_labels, use_cache=False + ).loss + + logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits + + lprobs = torch.nn.functional.log_softmax(logits, dim=-1) + smoothed_loss, nll_loss = label_smoothed_nll_loss( + lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id + ) + with self.assertRaises(AssertionError): + # TODO: understand why this breaks + self.assertEqual(nll_loss, model_computed_loss) + @unittest.skip("T5 distillation is broken at the moment") def test_distill_t5(self): updates = dict( @@ -156,6 +177,8 @@ class TestSummarizationDistiller(unittest.TestCase): def _test_distiller_cli(self, updates, check_contents=True): default_updates = dict( + label_smoothing_eps=0.0, + early_stopping_patience=-1, train_batch_size=1, eval_batch_size=2, max_epochs=2, @@ -215,6 +238,8 @@ def test_run_eval_bart(model): def test_finetune(model): args_d: dict = CHEAP_ARGS.copy() task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization" + args_d["label_smoothing"] = 0.1 if task == "translation" else 0 + tmp_dir = make_test_data_dir() output_dir = tempfile.mkdtemp(prefix="output_") args_d.update( diff --git a/examples/seq2seq/train_mbart_cc25_enro.sh b/examples/seq2seq/train_mbart_cc25_enro.sh index bb65bd8fb4..f646f340b8 100755 --- a/examples/seq2seq/train_mbart_cc25_enro.sh +++ b/examples/seq2seq/train_mbart_cc25_enro.sh @@ -4,17 +4,17 @@ export PYTHONPATH="../":"${PYTHONPATH}" python finetune.py \ --learning_rate=3e-5 \ --fp16 \ - --gpus 1 \ --do_train \ --do_predict \ - --val_check_interval 0.1 \ + --val_check_interval 0.25 \ --adam_eps 1e-06 \ - --num_train_epochs 3 --src_lang en_XX --tgt_lang ro_RO \ - --freeze_encoder --freeze_embeds --data_dir $ENRO_DIR \ + --num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \ + --data_dir $ENRO_DIR \ --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \ --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \ - --model_name_or_path facebook/mbart-large-cc25 \ --task translation \ --warmup_steps 500 \ - --logger_name wandb --sortish_sampler \ + --freeze_embeds \ + --early_stopping_patience 4 \ + --model_name_or_path facebook/mbart-large-cc25 \ $@ diff --git a/examples/seq2seq/train_mbart_cc25_enro_multigpu.sh b/examples/seq2seq/train_mbart_cc25_enro_multigpu.sh deleted file mode 100755 index 4368089d6d..0000000000 --- a/examples/seq2seq/train_mbart_cc25_enro_multigpu.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env bash -export PYTHONPATH="../":"${PYTHONPATH}" -# Need to export N_GPUS= -python finetune.py \ - --learning_rate=3e-5 \ - --fp16 \ - --gpus $N_GPUS \ - --do_train \ - --val_check_interval 0.25 \ - --adam_eps 1e-06 \ - --num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \ - --data_dir $ENRO_DIR \ - --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \ - --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \ - --tokenizer facebook/mbart-large-cc25 \ - --task translation \ - --warmup_steps 500 --freeze_encoder --freeze_embeds \ - $@ diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 38f04adee3..c2c484735f 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -19,6 +19,29 @@ from torch.utils.data import Dataset, Sampler from transformers import BartTokenizer +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): + """From fairseq""" + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + bs = pad_mask.long().sum() + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + bs = lprobs.shape[0] + + nll_loss = nll_loss.sum() # mean()? Scared to break other math. + smooth_loss = smooth_loss.sum() + eps_i = epsilon / lprobs.size(-1) + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss + return loss / bs, nll_loss / bs + + def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} return tokenizer( @@ -144,8 +167,8 @@ class MBartDataset(Seq2SeqDataset): assert source_line, f"empty source line for index {index}" assert tgt_line, f"empty tgt line for index {index}" return { - "tgt_texts": source_line, - "src_texts": tgt_line, + "tgt_texts": tgt_line, + "src_texts": source_line, } def collate_fn(self, batch) -> Dict[str, torch.Tensor]: