loading saved model when n_classes != 2
Required to for: Assertion `t >= 0 && t < n_classes` failed, if your default number of classes is not 2.
This commit is contained in:
@@ -558,7 +558,7 @@ def main():
|
||||
|
||||
# Load a trained model that you have fine-tuned
|
||||
model_state_dict = torch.load(output_model_file)
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
|
||||
model.to(device)
|
||||
|
||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
|
||||
Reference in New Issue
Block a user