[s2s] distillation apex breaks return_dict obj (#8631)
* apex breaks return_dict obj
* style
This commit is contained in:
@@ -154,7 +154,7 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
)
|
)
|
||||||
lm_logits = student_outputs.logits
|
lm_logits = student_outputs["logits"]
|
||||||
|
|
||||||
# Same cross entropy vs. label smoothing logic as finetune.py
|
# Same cross entropy vs. label smoothing logic as finetune.py
|
||||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||||
@@ -171,7 +171,9 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
def zero_tensor():
|
def zero_tensor():
|
||||||
return torch.tensor(0.0).type_as(student_lm_loss)
|
return torch.tensor(0.0).type_as(student_lm_loss)
|
||||||
|
|
||||||
teacher_enc_outputs = student_outputs.encoder_last_hidden_state # use this unless self.different_base_models
|
teacher_enc_outputs = student_outputs[
|
||||||
|
"encoder_last_hidden_state"
|
||||||
|
] # use this unless self.different_base_models
|
||||||
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
|
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
|
||||||
if self.different_encoder: # compute encoder hidden state loss
|
if self.different_encoder: # compute encoder hidden state loss
|
||||||
all_teacher_encoder_outputs = self.teacher.get_encoder()(
|
all_teacher_encoder_outputs = self.teacher.get_encoder()(
|
||||||
@@ -180,12 +182,12 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
output_hidden_states=self.do_calc_hidden_loss,
|
output_hidden_states=self.do_calc_hidden_loss,
|
||||||
)
|
)
|
||||||
if self.different_base_models:
|
if self.different_base_models:
|
||||||
teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
|
teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
|
||||||
elif self.do_calc_hidden_loss:
|
elif self.do_calc_hidden_loss:
|
||||||
hid_loss_enc = self.calc_hidden_loss(
|
hid_loss_enc = self.calc_hidden_loss(
|
||||||
src_mask,
|
src_mask,
|
||||||
student_outputs.encoder_hidden_states,
|
student_outputs["encoder_hidden_states"],
|
||||||
all_teacher_encoder_outputs.hidden_states,
|
all_teacher_encoder_outputs["hidden_states"],
|
||||||
self.e_matches,
|
self.e_matches,
|
||||||
normalize_hidden=self.hparams.normalize_hidden,
|
normalize_hidden=self.hparams.normalize_hidden,
|
||||||
)
|
)
|
||||||
@@ -199,12 +201,12 @@ class SummarizationDistiller(SummarizationModule):
|
|||||||
use_cache=False, # since we are not passing labels, never let this default to True
|
use_cache=False, # since we are not passing labels, never let this default to True
|
||||||
)
|
)
|
||||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||||
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
|
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"])
|
||||||
if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states
|
if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states
|
||||||
hid_loss_dec = self.calc_hidden_loss(
|
hid_loss_dec = self.calc_hidden_loss(
|
||||||
dec_mask,
|
dec_mask,
|
||||||
student_outputs.decoder_hidden_states,
|
student_outputs["decoder_hidden_states"],
|
||||||
teacher_outputs.decoder_hidden_states,
|
teacher_outputs["decoder_hidden_states"],
|
||||||
self.d_matches,
|
self.d_matches,
|
||||||
normalize_hidden=self.hparams.normalize_hidden,
|
normalize_hidden=self.hparams.normalize_hidden,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user