From d86d57faa3b6511c6e4d9139535d77b695b9af8a Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 18 Nov 2020 12:51:29 -0800 Subject: [PATCH] [s2s] distillation apex breaks return_dict obj (#8631) * apex breaks return_dict obj * style --- examples/seq2seq/distillation.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 9f2b9b713c..3b3bd80589 100755 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -154,7 +154,7 @@ class SummarizationDistiller(SummarizationModule): output_attentions=False, use_cache=False, ) - lm_logits = student_outputs.logits + lm_logits = student_outputs["logits"] # Same cross entropy vs. label smoothing logic as finetune.py assert lm_logits.shape[-1] == self.model.config.vocab_size @@ -171,7 +171,9 @@ class SummarizationDistiller(SummarizationModule): def zero_tensor(): return torch.tensor(0.0).type_as(student_lm_loss) - teacher_enc_outputs = student_outputs.encoder_last_hidden_state # use this unless self.different_base_models + teacher_enc_outputs = student_outputs[ + "encoder_last_hidden_state" + ] # use this unless self.different_base_models hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor() if self.different_encoder: # compute encoder hidden state loss all_teacher_encoder_outputs = self.teacher.get_encoder()( @@ -180,12 +182,12 @@ class SummarizationDistiller(SummarizationModule): output_hidden_states=self.do_calc_hidden_loss, ) if self.different_base_models: - teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state + teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"] elif self.do_calc_hidden_loss: hid_loss_enc = self.calc_hidden_loss( src_mask, - student_outputs.encoder_hidden_states, - all_teacher_encoder_outputs.hidden_states, + student_outputs["encoder_hidden_states"], + all_teacher_encoder_outputs["hidden_states"], self.e_matches, normalize_hidden=self.hparams.normalize_hidden, ) @@ -199,12 +201,12 @@ class SummarizationDistiller(SummarizationModule): use_cache=False, # since we are not passing labels, never let this default to True ) dec_mask = decoder_input_ids.ne(pad_token_id) - loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits) + loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"]) if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states hid_loss_dec = self.calc_hidden_loss( dec_mask, - student_outputs.decoder_hidden_states, - teacher_outputs.decoder_hidden_states, + student_outputs["decoder_hidden_states"], + teacher_outputs["decoder_hidden_states"], self.d_matches, normalize_hidden=self.hparams.normalize_hidden, )