diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py index f05a5847ed..2e8d0aa250 100644 --- a/examples/run_seq2seq_finetuning.py +++ b/examples/run_seq2seq_finetuning.py @@ -52,6 +52,10 @@ def set_seed(args): torch.manual_seed(args.seed) +# ------------ +# Load dataset +# ------------ + class TextDataset(Dataset): """ Abstracts the dataset used to train seq2seq models. @@ -212,6 +216,11 @@ def load_and_cache_examples(args, tokenizer): return dataset +# ------------ +# Train +# ------------ + + def train(args, train_dataset, model, tokenizer): """ Fine-tune the pretrained model on the corpus. """ raise NotImplementedError