[examples/s2s] clean up finetune_trainer (#7509)
This commit is contained in:
@@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(Trainer):
|
||||
- def __init__(self, data_args, *args, **kwargs):
|
||||
+ def __init__(self, config, data_args, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
+ self.config = config
|
||||
self.data_args = data_args
|
||||
self.max_gen_length = data_args.val_max_target_length
|
||||
- self.pad_token_id = self.model.config.pad_token_id
|
||||
+ self.pad_token_id = self.config.pad_token_id
|
||||
+ self.vocab_size = self.config.vocab_size
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||
@@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
|
||||
if self.args.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
|
||||
- assert logits.shape[-1] == self.model.config.vocab_size
|
||||
+ assert logits.shape[-1] == self.vocab_size
|
||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||
else:
|
||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user