From 81ebd70671497aac2e87dda4b0a1c4abb9af7ecd Mon Sep 17 00:00:00 2001 From: Sumithra Bhakthavatsalam Date: Wed, 11 Nov 2020 14:58:45 -0800 Subject: [PATCH] [s2s] distill t5-large -> t5-small (#8376) Co-authored-by: Sam Shleifer --- examples/seq2seq/README.md | 2 +- examples/seq2seq/distillation.py | 157 +++++++++++++--------- examples/seq2seq/test_bash_script.py | 4 +- examples/seq2seq/test_seq2seq_examples.py | 12 ++ 4 files changed, 108 insertions(+), 67 deletions(-) diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md index 2802347e66..450fbb3636 100644 --- a/examples/seq2seq/README.md +++ b/examples/seq2seq/README.md @@ -380,7 +380,7 @@ cp xsum/test* all_pl then use `all_pl` as DATA in the command above. #### Direct Knowledge Distillation (KD) -+ In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`. ++ In this method, we use try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `SummarizationDistiller`. + This method was used for `sshleifer/distilbart-xsum-12-6`, `6-6`, and `9-6` checkpoints were produced. + You must use [`distillation.py`](./distillation.py). Note that this command initializes the student for you. diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index d6b6f43695..e76e13fafe 100755 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -25,8 +25,8 @@ sys.path.insert(2, str(Path(__file__).resolve().parents[1])) from lightning_base import generic_train # noqa -class BartSummarizationDistiller(SummarizationModule): - """Supports Bart, Pegasus and other models that inherit from Bart.""" +class SummarizationDistiller(SummarizationModule): + """Supports T5, Bart, Pegasus and other models that inherit from Bart.""" loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"] @@ -40,26 +40,38 @@ class BartSummarizationDistiller(SummarizationModule): hparams.model_name_or_path = str(save_dir) # Tell lightning we are training the student teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval() use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default - student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers( - teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir - ) + if hparams.student is not None: + student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student) + use_task_specific_params(student, hparams.task) + e_layer_ids, d_layer_ids = None, None + else: + student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers( + teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir + ) + if hparams.length_penalty != -1: student.config.length_penalty = hparams.length_penalty hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer super().__init__(hparams, model=student, config=student.config) - model_type = student.config.model_type - self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int] + assert ( + student.config.model_type == teacher.config.model_type + ), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}" - if model_type == "t5": + if student.config.model_type == "t5": + student_encoder_layers = len(student.get_encoder().block) + student_decoder_layers = len(student.get_decoder().block) teacher_encoder_layers = len(teacher.get_encoder().block) teacher_decoder_layers = len(teacher.get_decoder().block) else: + student_encoder_layers = student.config.encoder_layers + student_decoder_layers = student.config.decoder_layers teacher_encoder_layers = teacher.config.encoder_layers teacher_decoder_layers = teacher.config.decoder_layers - self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers - self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers - + self.different_base_models = not (hparams.student is None or hparams.teacher == hparams.student) + self.do_calc_hidden_loss = (not self.different_base_models) and hparams.alpha_hid > 0 + self.different_encoder = self.different_base_models or (student_encoder_layers != teacher_encoder_layers) + # self.different_encoder determines whether we need to run the teacher encoder self.teacher = teacher freeze_params(self.teacher) @@ -68,13 +80,28 @@ class BartSummarizationDistiller(SummarizationModule): del self.teacher.model.encoder except AttributeError: # T5 del self.teacher.encoder - # Intermediate supervision: Decide which layers to supervise - if hparams.supervise_forward: - self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers) - self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers) - else: # student layer should emulate hidden states of the teacher layer it was copied from - self.e_matches = self.e_layer_ids - self.d_matches = self.d_layer_ids + + if e_layer_ids is None: + e_layer_ids = list(range(student_encoder_layers)) + if d_layer_ids is None: + d_layer_ids = list(range(student_decoder_layers)) + + self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int] + + if self.do_calc_hidden_loss: # Intermediate supervision: Decide which layers to supervise + if hparams.supervise_forward: + self.e_matches = get_layers_to_supervise( + n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers + ) + self.d_matches = get_layers_to_supervise( + n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers + ) + else: # student layer should emulate hidden states of the teacher layer it was copied from + self.e_matches = self.e_layer_ids + self.d_matches = self.d_layer_ids + else: + self.e_matches = None + self.d_matches = None self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") self.temperature = 2.0 @@ -84,22 +111,8 @@ class BartSummarizationDistiller(SummarizationModule): gc.collect() torch.cuda.empty_cache() - def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor: - """Supervise MSE(teacher.encoder_outputs, student.encoder_outputs).""" - # raise NotImplementedError() - if mask is not None: - # mask has False at padding_idx - sel_mask = mask[:, :, None].expand_as(student_outputs).bool() - s_logits_slct = torch.masked_select(student_outputs, sel_mask) - t_logits_slct = torch.masked_select(teacher_outputs, sel_mask) - else: - t_logits_slct = teacher_outputs - s_logits_slct = student_outputs - return F.mse_loss(s_logits_slct, t_logits_slct) - def calc_ce_loss(self, mask, s_logits, t_logits): """Copy pasted from distillbert (transformers/examples/distillation/)""" - # mask has False at padding_idx sel_mask = mask[:, :, None].expand_as(s_logits) vocab_size = s_logits.size(-1) @@ -123,8 +136,8 @@ class BartSummarizationDistiller(SummarizationModule): add_distill_args(parser) return parser - def _step(self, batch): - # assert is_frozen(self.teacher) copied_decoder_layers + def _step(self, batch: dict) -> tuple: + """Compute the loss for a batch""" pad_token_id = self.tokenizer.pad_token_id input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"] if isinstance(self.model, T5ForConditionalGeneration): @@ -133,14 +146,16 @@ class BartSummarizationDistiller(SummarizationModule): decoder_input_ids = shift_tokens_right(labels, pad_token_id) # noinspection PyCallingNonCallable - lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self( + student_outputs = self( input_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, - output_hidden_states=True, + output_hidden_states=self.do_calc_hidden_loss, output_attentions=False, use_cache=False, + return_dict=True, ) + lm_logits = student_outputs.logits # Same cross entropy vs. label smoothing logic as finetune.py assert lm_logits.shape[-1] == self.model.config.vocab_size @@ -149,7 +164,7 @@ class BartSummarizationDistiller(SummarizationModule): loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1)) else: - lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1) + lprobs = F.log_softmax(lm_logits, dim=-1) student_lm_loss, _ = label_smoothed_nll_loss( lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id ) @@ -157,37 +172,44 @@ class BartSummarizationDistiller(SummarizationModule): def zero_tensor(): return torch.tensor(0.0).type_as(student_lm_loss) + teacher_enc_outputs = student_outputs.encoder_last_hidden_state # use this unless self.different_base_models hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor() if self.different_encoder: # compute encoder hidden state loss - with torch.no_grad(): - teacher_enc_hid = self.teacher.get_encoder()( - input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True - ).hidden_states - - hid_loss_enc = self.calc_hidden_loss( - src_mask, - enc_hidden_state, - teacher_enc_hid, - self.e_matches, - normalize_hidden=self.hparams.normalize_hidden, - ) - - with torch.no_grad(): - outputs = self.teacher( + all_teacher_encoder_outputs = self.teacher.get_encoder()( input_ids, attention_mask=src_mask, - encoder_outputs=(enc_outputs,), - decoder_input_ids=decoder_input_ids, - lm_labels=labels, - output_hidden_states=True, + output_hidden_states=self.do_calc_hidden_loss, return_dict=True, ) - tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states + if self.different_base_models: + teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state + elif self.do_calc_hidden_loss: + hid_loss_enc = self.calc_hidden_loss( + src_mask, + student_outputs.encoder_hidden_states, + all_teacher_encoder_outputs.hidden_states, + self.e_matches, + normalize_hidden=self.hparams.normalize_hidden, + ) + + teacher_outputs = self.teacher( + input_ids, + attention_mask=src_mask, + encoder_outputs=(teacher_enc_outputs,), + decoder_input_ids=decoder_input_ids, + output_hidden_states=self.do_calc_hidden_loss, + use_cache=False, # since we are not passing labels, never let this default to True + return_dict=True, + ) dec_mask = decoder_input_ids.ne(pad_token_id) - loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits) - if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states + loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits) + if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states hid_loss_dec = self.calc_hidden_loss( - dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden + dec_mask, + student_outputs.decoder_hidden_states, + teacher_outputs.decoder_hidden_states, + self.d_matches, + normalize_hidden=self.hparams.normalize_hidden, ) blended_loss = ( @@ -207,6 +229,7 @@ class BartSummarizationDistiller(SummarizationModule): valid_count = mask.sum() * hidden_states[0].size(-1) student_states = torch.stack([hidden_states[i] for i in range(len(matches))]) teacher_states = torch.stack([hidden_states_T[j] for j in matches]) + assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}" if normalize_hidden: student_states = F.layer_norm(student_states, student_states.shape[1:]) teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:]) @@ -216,10 +239,16 @@ class BartSummarizationDistiller(SummarizationModule): def add_distill_args(parser): + # NOTE: if --student argument was specified and the teacher and student base models + # are different, the models still have to have the same tokenizer, specified by + # --tokenizer_name. So, for example, you can distill from t5_large to t5_small but not + # from bart to t5. This s because if the tokenizers are different, the output space + # for the two models is also different and their logits are not comparable. parser.add_argument("--teacher", type=str) parser.add_argument("--alpha_ce", default=0.8, type=float) parser.add_argument("--alpha_mlm", default=0.2, type=float) parser.add_argument("--alpha_hid", default=0.0, type=float, required=False) + parser.add_argument("--student", type=str, required=False) parser.add_argument("--student_decoder_layers", default=12, type=int, required=False) parser.add_argument("--student_encoder_layers", default=12, type=int, required=False) parser.add_argument("--no_teacher", action="store_true", default=False) @@ -228,8 +257,8 @@ def add_distill_args(parser): parser.add_argument("--normalize_hidden", action="store_true", default=False) -class BartTranslationDistiller(BartSummarizationDistiller): - """Supports Mbart, Marian, other models that inherit from Bart.""" +class TranslationDistiller(SummarizationDistiller): + """Supports T5, mBART, Marian, other models that inherit from Bart.""" mode = "translation" metric_names = ["bleu"] @@ -258,7 +287,7 @@ def create_module(args): if args.no_teacher: module_cls = TranslationModule if "translation" in args.task else SummarizationModule else: # DISTILL WITH TEACHER - module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller + module_cls = TranslationDistiller if "translation" in args.task else SummarizationDistiller args.setup_cls: str = module_cls.__name__ print(f"using module {args.setup_cls}") model = module_cls(args) @@ -276,7 +305,7 @@ def distill_main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) - parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd()) + parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd()) args = parser.parse_args() distill_main(args) diff --git a/examples/seq2seq/test_bash_script.py b/examples/seq2seq/test_bash_script.py index bf4302fd9f..9e4ffc2217 100644 --- a/examples/seq2seq/test_bash_script.py +++ b/examples/seq2seq/test_bash_script.py @@ -9,7 +9,7 @@ import pytorch_lightning as pl import timeout_decorator import torch -from distillation import BartSummarizationDistiller, distill_main +from distillation import SummarizationDistiller, distill_main from finetune import SummarizationModule, main from transformers import MarianMTModel from transformers.file_utils import cached_path @@ -170,7 +170,7 @@ class TestDistilMarianNoTeacher(TestCasePlus): with patch.object(sys, "argv", testargs): parser = argparse.ArgumentParser() parser = pl.Trainer.add_argparse_args(parser) - parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd()) + parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd()) args = parser.parse_args() # assert args.gpus == gpus THIS BREAKS for multi_gpu diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index a2b306032e..2fee837fe9 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -96,6 +96,7 @@ CHEAP_ARGS = { "freeze_encoder": False, "auto_scale_batch_size": False, "overwrite_output_dir": False, + "student": None, } @@ -107,6 +108,7 @@ def _dump_articles(path: Path, articles: list): ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."] SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"] T5_TINY = "patrickvonplaten/t5-tiny-random" +T5_TINIER = "sshleifer/t5-tinier-random" BART_TINY = "sshleifer/bart-tiny-random" MBART_TINY = "sshleifer/tiny-mbart" MARIAN_TINY = "sshleifer/tiny-marian-en-de" @@ -239,6 +241,16 @@ class TestSummarizationDistiller(TestCasePlus): ) self._test_distiller_cli(updates) + @require_torch_non_multi_gpu_but_fix_me + def test_distill_different_base_models(self): + updates = dict( + teacher=T5_TINY, + student=T5_TINIER, + model_name_or_path=T5_TINIER, + tokenizer_name=T5_TINIER, + ) + self._test_distiller_cli(updates) + def _test_distiller_cli(self, updates, check_contents=True): default_updates = dict( label_smoothing=0.0,