add separator between data import and train
This commit is contained in:
@@ -52,6 +52,10 @@ def set_seed(args):
|
|||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Load dataset
|
||||||
|
# ------------
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
class TextDataset(Dataset):
|
||||||
""" Abstracts the dataset used to train seq2seq models.
|
""" Abstracts the dataset used to train seq2seq models.
|
||||||
|
|
||||||
@@ -212,6 +216,11 @@ def load_and_cache_examples(args, tokenizer):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Train
|
||||||
|
# ------------
|
||||||
|
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer):
|
def train(args, train_dataset, model, tokenizer):
|
||||||
""" Fine-tune the pretrained model on the corpus. """
|
""" Fine-tune the pretrained model on the corpus. """
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
Reference in New Issue
Block a user