Upgrade PyTorch Lightning to 1.0.2 (#7852)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Sean Naren
2020-10-28 18:59:14 +00:00
committed by GitHub
parent 1b6c8d4811
commit 5e24982e58
8 changed files with 11 additions and 13 deletions

View File

@@ -192,7 +192,7 @@ def main():
# Optionally, predict on dev set and write to output_dir
if args.do_predict:
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
model = model.load_from_checkpoint(checkpoints[-1])
return trainer.test(model)