Upgrade PyTorch Lightning to 1.0.2 (#7852)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
@@ -207,9 +207,9 @@ if __name__ == "__main__":
|
||||
|
||||
if args.do_predict:
|
||||
# See https://github.com/huggingface/transformers/issues/3159
|
||||
# pl use this format to create a checkpoint:
|
||||
# pl use this default format to create a checkpoint:
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
|
||||
# /pytorch_lightning/callbacks/model_checkpoint.py#L169
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
|
||||
# /pytorch_lightning/callbacks/model_checkpoint.py#L322
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
|
||||
model = model.load_from_checkpoint(checkpoints[-1])
|
||||
trainer.test(model)
|
||||
|
||||
Reference in New Issue
Block a user