Black 20 release
This commit is contained in:
@@ -348,7 +348,10 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||
if self.different_encoder:
|
||||
with torch.no_grad():
|
||||
teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder(
|
||||
source_ids, attention_mask=source_mask, output_hidden_states=True, use_cache=False,
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
if self.hparams.alpha_encoder_loss > 0:
|
||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask)
|
||||
|
||||
Reference in New Issue
Block a user