[Seq2SeqTrainer] Move import to init to make file self-contained (#8194)
* boom boom * reverse order
This commit is contained in:
committed by
GitHub
parent
1f12934df4
commit
9bd30f7cf4
@@ -20,12 +20,6 @@ from transformers.optimization import (
|
|||||||
from transformers.trainer_pt_utils import get_tpu_sampler
|
from transformers.trainer_pt_utils import get_tpu_sampler
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from .utils import label_smoothed_nll_loss
|
|
||||||
except ImportError:
|
|
||||||
from utils import label_smoothed_nll_loss
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
arg_to_scheduler = {
|
arg_to_scheduler = {
|
||||||
@@ -64,6 +58,17 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
|
f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.args.label_smoothing == 0:
|
||||||
|
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
||||||
|
else:
|
||||||
|
# dynamically import label_smoothed_nll_loss
|
||||||
|
try:
|
||||||
|
from .utils import label_smoothed_nll_loss
|
||||||
|
except ImportError:
|
||||||
|
from utils import label_smoothed_nll_loss
|
||||||
|
|
||||||
|
self.loss_fn = label_smoothed_nll_loss
|
||||||
|
|
||||||
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
||||||
"""
|
"""
|
||||||
Setup the optimizer and the learning rate scheduler.
|
Setup the optimizer and the learning rate scheduler.
|
||||||
@@ -135,9 +140,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
|
if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
|
||||||
# force training to ignore pad token
|
# force training to ignore pad token
|
||||||
logits = model(**inputs, use_cache=False)[0]
|
logits = model(**inputs, use_cache=False)[0]
|
||||||
|
loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
|
|
||||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
|
||||||
else:
|
else:
|
||||||
# compute usual loss via models
|
# compute usual loss via models
|
||||||
loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
|
loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
|
||||||
@@ -145,9 +148,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
# compute label smoothed loss
|
# compute label smoothed loss
|
||||||
logits = model(**inputs, use_cache=False)[0]
|
logits = model(**inputs, use_cache=False)[0]
|
||||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
loss, _ = label_smoothed_nll_loss(
|
loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id)
|
||||||
lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
|
|
||||||
)
|
|
||||||
return loss, logits
|
return loss, logits
|
||||||
|
|
||||||
def compute_loss(self, model, inputs):
|
def compute_loss(self, model, inputs):
|
||||||
|
|||||||
Reference in New Issue
Block a user