From e5c393dcebf42eaec9c1e1d619b5a7788a2d7c65 Mon Sep 17 00:00:00 2001
From: Ethan Perez
Date: Mon, 30 Mar 2020 16:06:08 -0500
Subject: [PATCH] =?UTF-8?q?[Bug=20fix]=20Using=20loaded=20checkpoint=20wit?=
 =?UTF-8?q?h=20--do=5Fpredict=20(instead=20of=E2=80=A6=20(#3437)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Using loaded checkpoint with --do_predict

Without this fix, I'm getting near-random validation performance for a trained model, and the validation performance differs per validation run. I think this happens since the `model` variable isn't set with the loaded checkpoint, so I'm using a randomly initialized model. Looking at the model activations, they differ each time I run evaluation (but they don't with this fix).

* Update checkpoint loading

* Fixing model loading
---
 examples/glue/run_pl_glue.py | 2 +-
 examples/ner/run_pl_ner.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/glue/run_pl_glue.py b/examples/glue/run_pl_glue.py
index 18361f36db..44031cce5f 100644
--- a/examples/glue/run_pl_glue.py
+++ b/examples/glue/run_pl_glue.py
@@ -192,5 +192,5 @@ if __name__ == "__main__":
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
         checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        GLUETransformer.load_from_checkpoint(checkpoints[-1])
+        model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
diff --git a/examples/ner/run_pl_ner.py b/examples/ner/run_pl_ner.py
index 6b484faa38..f5cbf5bd3f 100644
--- a/examples/ner/run_pl_ner.py
+++ b/examples/ner/run_pl_ner.py
@@ -192,5 +192,5 @@ if __name__ == "__main__":
         # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
         # /pytorch_lightning/callbacks/model_checkpoint.py#L169
         checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        NERTransformer.load_from_checkpoint(checkpoints[-1])
+        model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)