[Seq2Seq] Correct import in Seq2Seq Trainer (#8254)

This commit is contained in:
Patrick von Platen
2020-11-03 13:56:41 +01:00
committed by GitHub
parent 504ff7bb12
commit 9f1747f999

View File

@@ -62,10 +62,7 @@ class Seq2SeqTrainer(Trainer):
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id) self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
else: else:
# dynamically import label_smoothed_nll_loss # dynamically import label_smoothed_nll_loss
try: from utils import label_smoothed_nll_loss
from .utils import label_smoothed_nll_loss
except ImportError:
from utils import label_smoothed_nll_loss
self.loss_fn = label_smoothed_nll_loss self.loss_fn = label_smoothed_nll_loss