[examples] summarization/bart/finetune.py supports t5 (#3824)

renames `run_bart_sum.py` to `finetune.py`
This commit is contained in:
Sam Shleifer
2020-04-16 15:15:19 -04:00
committed by GitHub
parent 0cec4fab7d
commit f0c96fafd1
5 changed files with 36 additions and 14 deletions

View File

@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
from transformers import BartTokenizer
from .evaluate_cnn import run_generate
from .run_bart_sum import main
from .finetune import main
from .utils import SummarizationDataset
@@ -92,9 +92,27 @@ class TestBartExamples(unittest.TestCase):
args_d.update(
data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
)
main(argparse.Namespace(**args_d))
args_d.update({"do_train": False, "do_predict": True})
main(argparse.Namespace(**args_d))
args = argparse.Namespace(**args_d)
main(args)
def test_t5_run_sum_cli(self):
args_d: dict = DEFAULT_ARGS.copy()
tmp_dir = make_test_data_dir()
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
train_batch_size=2,
eval_batch_size=2,
n_gpu=0,
output_dir=output_dir,
do_predict=True,
)
main(argparse.Namespace(**args_d))
# args_d.update({"do_train": False, "do_predict": True})
# main(argparse.Namespace(**args_d))
def test_bart_summarization_dataset(self):
tmp_dir = Path(tempfile.gettempdir())