[examples/s2s] clean up finetune_trainer (#7509)

This commit is contained in:
Suraj Patil
2020-10-01 21:49:29 +05:30
committed by GitHub
parent bd2621583b
commit 72d363d979
3 changed files with 107 additions and 105 deletions

View File

@@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)
class Seq2SeqTrainer(Trainer):
def __init__(self, data_args, *args, **kwargs):
def __init__(self, config, data_args, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.data_args = data_args
self.max_gen_length = data_args.val_max_target_length
self.pad_token_id = self.model.config.pad_token_id
self.pad_token_id = self.config.pad_token_id
self.vocab_size = self.config.vocab_size
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
@@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
if self.args.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
assert logits.shape[-1] == self.model.config.vocab_size
assert logits.shape[-1] == self.vocab_size
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)