diff --git a/examples/run_classifier.py b/examples/run_classifier.py index adf81f4e28..456b06b07f 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -558,7 +558,7 @@ def main(): # Load a trained model that you have fine-tuned model_state_dict = torch.load(output_model_file) - model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict) + model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels) model.to(device) if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):