no fp16 on evaluation
This commit is contained in:
@@ -558,8 +558,6 @@ def main():
|
|||||||
# Load a trained model that you have fine-tuned
|
# Load a trained model that you have fine-tuned
|
||||||
model_state_dict = torch.load(output_model_file)
|
model_state_dict = torch.load(output_model_file)
|
||||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||||
if args.fp16:
|
|
||||||
model.half()
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
|
|||||||
@@ -923,8 +923,6 @@ def main():
|
|||||||
# Load a trained model that you have fine-tuned
|
# Load a trained model that you have fine-tuned
|
||||||
model_state_dict = torch.load(output_model_file)
|
model_state_dict = torch.load(output_model_file)
|
||||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||||
if args.fp16:
|
|
||||||
model.half()
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
|
|||||||
@@ -478,8 +478,6 @@ def main():
|
|||||||
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
||||||
state_dict=model_state_dict,
|
state_dict=model_state_dict,
|
||||||
num_choices=4)
|
num_choices=4)
|
||||||
if args.fp16:
|
|
||||||
model.half()
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
|
|||||||
Reference in New Issue
Block a user