[examples] fix summarization do_predict (#3866)

This commit is contained in:
Sam Shleifer
2020-04-20 10:49:56 -04:00
committed by GitHub
parent 52c85f847a
commit a504cb49ec
3 changed files with 20 additions and 6 deletions

View File

@@ -94,7 +94,15 @@ class TestBartExamples(unittest.TestCase):
)
main(argparse.Namespace(**args_d))
args_d.update({"do_train": False, "do_predict": True})
main(argparse.Namespace(**args_d))
contents = os.listdir(output_dir)
expected_contents = {
"checkpointepoch=0.ckpt",
"test_results.txt",
}
created_files = {os.path.basename(p) for p in contents}
self.assertSetEqual(expected_contents, created_files)
def test_t5_run_sum_cli(self):
args_d: dict = DEFAULT_ARGS.copy()
@@ -111,6 +119,7 @@ class TestBartExamples(unittest.TestCase):
do_predict=True,
)
main(argparse.Namespace(**args_d))
# args_d.update({"do_train": False, "do_predict": True})
# main(argparse.Namespace(**args_d))