diff --git a/examples/research_projects/distillation/distiller.py b/examples/research_projects/distillation/distiller.py index a9716506c1..fc5dc58941 100644 --- a/examples/research_projects/distillation/distiller.py +++ b/examples/research_projects/distillation/distiller.py @@ -380,21 +380,19 @@ class Distiller: lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM). """ if self.mlm: - s_logits, s_hidden_states = self.student( + student_outputs = self.student( input_ids=input_ids, attention_mask=attention_mask ) # (bs, seq_length, voc_size) with torch.no_grad(): - t_logits, t_hidden_states = self.teacher( + teacher_outputs = self.teacher( input_ids=input_ids, attention_mask=attention_mask ) # (bs, seq_length, voc_size) else: - s_logits, _, s_hidden_states = self.student( - input_ids=input_ids, attention_mask=None - ) # (bs, seq_length, voc_size) + student_outputs = self.student(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size) with torch.no_grad(): - t_logits, _, t_hidden_states = self.teacher( - input_ids=input_ids, attention_mask=None - ) # (bs, seq_length, voc_size) + teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=None) # (bs, seq_length, voc_size) + s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"] + t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"] assert s_logits.size() == t_logits.size() # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100