[Seq2Seq] Correct import in Seq2Seq Trainer (#8254)
This commit is contained in:
committed by
GitHub
parent
504ff7bb12
commit
9f1747f999
@@ -62,9 +62,6 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
||||||
else:
|
else:
|
||||||
# dynamically import label_smoothed_nll_loss
|
# dynamically import label_smoothed_nll_loss
|
||||||
try:
|
|
||||||
from .utils import label_smoothed_nll_loss
|
|
||||||
except ImportError:
|
|
||||||
from utils import label_smoothed_nll_loss
|
from utils import label_smoothed_nll_loss
|
||||||
|
|
||||||
self.loss_fn = label_smoothed_nll_loss
|
self.loss_fn = label_smoothed_nll_loss
|
||||||
|
|||||||
Reference in New Issue
Block a user