[examples] fix summarization do_predict (#3866)

This commit is contained in:
Sam Shleifer
2020-04-20 10:49:56 -04:00
committed by GitHub
parent 52c85f847a
commit a504cb49ec
3 changed files with 20 additions and 6 deletions

View File

@@ -166,8 +166,12 @@ def main(args):
# Optionally, predict on dev set and write to output_dir
if args.do_predict:
# See https://github.com/huggingface/transformers/issues/3159
# pl use this format to create a checkpoint:
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
model = model.load_from_checkpoint(checkpoints[-1])
trainer.test(model)

View File

@@ -94,7 +94,15 @@ class TestBartExamples(unittest.TestCase):
)
main(argparse.Namespace(**args_d))
args_d.update({"do_train": False, "do_predict": True})
main(argparse.Namespace(**args_d))
contents = os.listdir(output_dir)
expected_contents = {
"checkpointepoch=0.ckpt",
"test_results.txt",
}
created_files = {os.path.basename(p) for p in contents}
self.assertSetEqual(expected_contents, created_files)
def test_t5_run_sum_cli(self):
args_d: dict = DEFAULT_ARGS.copy()
@@ -111,6 +119,7 @@ class TestBartExamples(unittest.TestCase):
do_predict=True,
)
main(argparse.Namespace(**args_d))
# args_d.update({"do_train": False, "do_predict": True})
# main(argparse.Namespace(**args_d))