[examples] fix summarization do_predict (#3866)
This commit is contained in:
@@ -166,8 +166,12 @@ def main(args):
|
||||
|
||||
# Optionally, predict on dev set and write to output_dir
|
||||
if args.do_predict:
|
||||
# See https://github.com/huggingface/transformers/issues/3159
|
||||
# pl use this format to create a checkpoint:
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
|
||||
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
|
||||
SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
|
||||
model = model.load_from_checkpoint(checkpoints[-1])
|
||||
trainer.test(model)
|
||||
|
||||
|
||||
|
||||
@@ -94,7 +94,15 @@ class TestBartExamples(unittest.TestCase):
|
||||
)
|
||||
main(argparse.Namespace(**args_d))
|
||||
args_d.update({"do_train": False, "do_predict": True})
|
||||
|
||||
main(argparse.Namespace(**args_d))
|
||||
contents = os.listdir(output_dir)
|
||||
expected_contents = {
|
||||
"checkpointepoch=0.ckpt",
|
||||
"test_results.txt",
|
||||
}
|
||||
created_files = {os.path.basename(p) for p in contents}
|
||||
self.assertSetEqual(expected_contents, created_files)
|
||||
|
||||
def test_t5_run_sum_cli(self):
|
||||
args_d: dict = DEFAULT_ARGS.copy()
|
||||
@@ -111,6 +119,7 @@ class TestBartExamples(unittest.TestCase):
|
||||
do_predict=True,
|
||||
)
|
||||
main(argparse.Namespace(**args_d))
|
||||
|
||||
# args_d.update({"do_train": False, "do_predict": True})
|
||||
# main(argparse.Namespace(**args_d))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user