fix reloading model for evaluation in examples

This commit is contained in:
thomwolf
2018-12-13 14:48:12 +01:00
parent 0f544625f4
commit 087798b7fa
4 changed files with 65 additions and 17 deletions

View File

@@ -366,8 +366,7 @@ def main():
# Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
num_choices = 4
)
num_choices=4)
if args.fp16:
model.half()
model.to(device)
@@ -452,6 +451,9 @@ def main():
loss = loss * args.loss_scale
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
if args.fp16:
optimizer.backward(loss)
@@ -466,6 +468,20 @@ def main():
optimizer.zero_grad()
global_step += 1
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file)
# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForMultipleChoice.from_pretrained(args.bert_model,
state_dict=model_state_dict,
num_choices=4)
if args.fp16:
model.half()
model.to(device)
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
eval_features = convert_examples_to_features(