[seq2seq] make it easier to run the scripts (#7274)
This commit is contained in:
9
examples/seq2seq/finetune.py
Normal file → Executable file
9
examples/seq2seq/finetune.py
Normal file → Executable file
@@ -1,7 +1,10 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@@ -13,7 +16,6 @@ import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from utils import (
|
||||
@@ -34,6 +36,11 @@ from utils import (
|
||||
)
|
||||
|
||||
|
||||
# need the parent dir module
|
||||
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train # noqa
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user