fix reloading model for evaluation in examples
This commit is contained in:
@@ -558,6 +558,9 @@ def main():
|
||||
# Load a trained model that you have fine-tuned
|
||||
model_state_dict = torch.load(output_model_file)
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
|
||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||
|
||||
@@ -923,6 +923,9 @@ def main():
|
||||
# Load a trained model that you have fine-tuned
|
||||
model_state_dict = torch.load(output_model_file)
|
||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
|
||||
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
eval_examples = read_squad_examples(
|
||||
|
||||
@@ -366,8 +366,7 @@ def main():
|
||||
# Prepare model
|
||||
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
|
||||
num_choices = 4
|
||||
)
|
||||
num_choices=4)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
@@ -452,6 +451,9 @@ def main():
|
||||
loss = loss * args.loss_scale
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
tr_loss += loss.item()
|
||||
nb_tr_examples += input_ids.size(0)
|
||||
nb_tr_steps += 1
|
||||
|
||||
if args.fp16:
|
||||
optimizer.backward(loss)
|
||||
@@ -466,6 +468,20 @@ def main():
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
|
||||
# Save a trained model
|
||||
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
|
||||
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
|
||||
torch.save(model_to_save.state_dict(), output_model_file)
|
||||
|
||||
# Load a trained model that you have fine-tuned
|
||||
model_state_dict = torch.load(output_model_file)
|
||||
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
||||
state_dict=model_state_dict,
|
||||
num_choices=4)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
|
||||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
|
||||
eval_features = convert_examples_to_features(
|
||||
|
||||
Reference in New Issue
Block a user