From 99709ee61d887ac1a4431a54a4f78f008b5b11d6 Mon Sep 17 00:00:00 2001 From: Jasdeep Singh <33911313+SinghJasdeep@users.noreply.github.com> Date: Thu, 20 Dec 2018 13:55:47 -0800 Subject: [PATCH] loading saved model when n_classes != 2 Required to for: Assertion `t >= 0 && t < n_classes` failed, if your default number of classes is not 2. --- examples/run_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index adf81f4e28..456b06b07f 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -558,7 +558,7 @@ def main(): # Load a trained model that you have fine-tuned model_state_dict = torch.load(output_model_file) - model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict) + model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels) model.to(device) if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):