diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 1e8ff40306..dc44a7322d 100755 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -28,7 +28,7 @@ from lightning_base import generic_train # noqa class BartSummarizationDistiller(SummarizationModule): """Supports Bart, Pegasus and other models that inherit from Bart.""" - loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"] + loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"] def __init__(self, hparams): assert Path(hparams.data_dir).exists() @@ -46,9 +46,19 @@ class BartSummarizationDistiller(SummarizationModule): if hparams.length_penalty != -1: student.config.length_penalty = hparams.length_penalty super().__init__(hparams, model=student, config=student.config) + model_type = student.config.model_type self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int] - self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers - self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers + + if model_type == "t5": + teacher_encoder_layers = len(teacher.get_encoder().block) + teacher_decoder_layers = len(teacher.get_decoder().block) + else: + teacher_encoder_layers = teacher.config.encoder_layers + teacher_decoder_layers = teacher.config.decoder_layers + + self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers + self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers + self.teacher = teacher freeze_params(self.teacher) @@ -59,17 +69,17 @@ class BartSummarizationDistiller(SummarizationModule): del self.teacher.encoder # Intermediate supervision: Decide which layers to supervise if hparams.supervise_forward: - self.d_matches = get_layers_to_supervise( - n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers - ) - else: + self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers) + self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers) + else: # student layer should emulate hidden states of the teacher layer it was copied from + self.e_matches = self.e_layer_ids self.d_matches = self.d_layer_ids + self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") self.temperature = 2.0 self.alpha_mlm = hparams.alpha_mlm self.alpha_ce = hparams.alpha_ce self.alpha_hid = hparams.alpha_hid - self.alpha_encoder_loss = hparams.alpha_encoder_loss gc.collect() torch.cuda.empty_cache() @@ -129,7 +139,7 @@ class BartSummarizationDistiller(SummarizationModule): output_hidden_states=True, output_attentions=False, use_cache=False, - ) # TODO(@sshleifer): return_dict=True cleanup + ) # Same cross entropy vs. label smoothing logic as finetune.py assert lm_logits.shape[-1] == self.model.config.vocab_size @@ -146,30 +156,32 @@ class BartSummarizationDistiller(SummarizationModule): def zero_tensor(): return torch.tensor(0.0).type_as(student_lm_loss) - loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor() - if self.different_encoder: + hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor() + if self.different_encoder: # compute encoder hidden state loss with torch.no_grad(): - teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()( - input_ids, attention_mask=src_mask, output_hidden_states=True - ) - # DEPRECATE THIS - if self.hparams.alpha_encoder_loss > 0: - loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask) + teacher_enc_hid = self.teacher.get_encoder()( + input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True + ).hidden_states - hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids) - - teacher_enc_outputs = (enc_outputs,) - assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs) + hid_loss_enc = self.calc_hidden_loss( + src_mask, + enc_hidden_state, + teacher_enc_hid, + self.e_matches, + normalize_hidden=self.hparams.normalize_hidden, + ) with torch.no_grad(): - tloss, tlogits, tdec_hidden, _ = self.teacher( + outputs = self.teacher( input_ids, attention_mask=src_mask, - encoder_outputs=teacher_enc_outputs, + encoder_outputs=(enc_outputs,), decoder_input_ids=decoder_input_ids, lm_labels=labels, output_hidden_states=True, + return_dict=True, ) + tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states dec_mask = decoder_input_ids.ne(pad_token_id) loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits) if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states @@ -180,10 +192,9 @@ class BartSummarizationDistiller(SummarizationModule): blended_loss = ( self.alpha_ce * loss_ce + self.alpha_mlm * student_lm_loss - + self.hparams.alpha_encoder_loss * loss_encoder + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec) ) - return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec + return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec @staticmethod def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden): @@ -207,7 +218,6 @@ def add_distill_args(parser): parser.add_argument("--teacher", type=str) parser.add_argument("--alpha_ce", default=0.8, type=float) parser.add_argument("--alpha_mlm", default=0.2, type=float) - parser.add_argument("--alpha_encoder_loss", default=0.0, type=float) parser.add_argument("--alpha_hid", default=0.0, type=float, required=False) parser.add_argument("--student_decoder_layers", default=12, type=int, required=False) parser.add_argument("--student_encoder_layers", default=12, type=int, required=False) diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 3e054649bc..1b6c505c94 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -86,7 +86,6 @@ CHEAP_ARGS = { "n_val": -1, "n_test": -1, "student_encoder_layers": 1, - "alpha_encoder_loss": 0.0, "freeze_encoder": False, "auto_scale_batch_size": False, } @@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase): evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp())) - @unittest.skip("T5 distillation is broken at the moment") def test_distill_t5(self): updates = dict( student_encoder_layers=1, @@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase): model_name_or_path="sshleifer/tinier_bart", teacher=CHEAP_ARGS["model_name_or_path"], val_check_interval=0.5, - alpha_encoder_loss=0.4, ) default_updates.update(updates) args_d: dict = CHEAP_ARGS.copy()