@@ -148,7 +148,7 @@ class SummarizationModule(BaseTransformer):
|
|||||||
self.save_readable_batch(batch)
|
self.save_readable_batch(batch)
|
||||||
|
|
||||||
outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
|
outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
|
||||||
lm_logits = outputs[0]
|
lm_logits = outputs["logits"]
|
||||||
if self.hparams.label_smoothing == 0:
|
if self.hparams.label_smoothing == 0:
|
||||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||||
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user