Added a hack to also evaluate on the mismatched MNLI (MNLI-MM) dev set
This commit is contained in:
@@ -679,7 +679,6 @@ def main():
|
|||||||
output_modes = {
|
output_modes = {
|
||||||
"cola": "classification",
|
"cola": "classification",
|
||||||
"mnli": "classification",
|
"mnli": "classification",
|
||||||
"mnli-mm": "classification",
|
|
||||||
"mrpc": "classification",
|
"mrpc": "classification",
|
||||||
"sst-2": "classification",
|
"sst-2": "classification",
|
||||||
"sts-b": "regression",
|
"sts-b": "regression",
|
||||||
@@ -930,6 +929,8 @@ def main():
|
|||||||
preds = preds[0]
|
preds = preds[0]
|
||||||
if output_mode == "classification":
|
if output_mode == "classification":
|
||||||
preds = np.argmax(preds, axis=1)
|
preds = np.argmax(preds, axis=1)
|
||||||
|
elif output_mode == "regression":
|
||||||
|
preds = np.squeeze(preds)
|
||||||
result = compute_metrics(task_name, preds, all_label_ids.numpy())
|
result = compute_metrics(task_name, preds, all_label_ids.numpy())
|
||||||
loss = tr_loss/nb_tr_steps if args.do_train else None
|
loss = tr_loss/nb_tr_steps if args.do_train else None
|
||||||
|
|
||||||
@@ -943,6 +944,69 @@ def main():
|
|||||||
for key in sorted(result.keys()):
|
for key in sorted(result.keys()):
|
||||||
logger.info(" %s = %s", key, str(result[key]))
|
logger.info(" %s = %s", key, str(result[key]))
|
||||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||||
|
|
||||||
|
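# HACK: when the task is MNLI, also evaluate on the mismatched dev set (MNLI-MM); results are written under args.output_dir + '-MM'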
|
||||||
|
if task_name == "mnli":
|
||||||
|
task_name = "mnli-mm"
|
||||||
|
processor = processors[task_name]()
|
||||||
|
|
||||||
|
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||||
|
eval_features = convert_examples_to_features(
|
||||||
|
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||||
|
logger.info("***** Running evaluation *****")
|
||||||
|
logger.info(" Num examples = %d", len(eval_examples))
|
||||||
|
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||||
|
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||||
|
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||||
|
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||||
|
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||||
|
|
||||||
|
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||||
|
# Run prediction for full data
|
||||||
|
eval_sampler = SequentialSampler(eval_data)
|
||||||
|
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
eval_loss = 0
|
||||||
|
nb_eval_steps = 0
|
||||||
|
preds = []
|
||||||
|
|
||||||
|
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
||||||
|
input_ids = input_ids.to(device)
|
||||||
|
input_mask = input_mask.to(device)
|
||||||
|
segment_ids = segment_ids.to(device)
|
||||||
|
label_ids = label_ids.to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(input_ids, segment_ids, input_mask, labels=None)
|
||||||
|
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
||||||
|
|
||||||
|
eval_loss += tmp_eval_loss.mean().item()
|
||||||
|
nb_eval_steps += 1
|
||||||
|
if len(preds) == 0:
|
||||||
|
preds.append(logits.detach().cpu().numpy())
|
||||||
|
else:
|
||||||
|
preds[0] = np.append(
|
||||||
|
preds[0], logits.detach().cpu().numpy(), axis=0)
|
||||||
|
|
||||||
|
eval_loss = eval_loss / nb_eval_steps
|
||||||
|
preds = preds[0]
|
||||||
|
preds = np.argmax(preds, axis=1)
|
||||||
|
result = compute_metrics(task_name, preds, all_label_ids.numpy())
|
||||||
|
loss = tr_loss/nb_tr_steps if args.do_train else None
|
||||||
|
|
||||||
|
result['eval_loss'] = eval_loss
|
||||||
|
result['global_step'] = global_step
|
||||||
|
result['loss'] = loss
|
||||||
|
|
||||||
|
output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
|
||||||
|
with open(output_eval_file, "w") as writer:
|
||||||
|
logger.info("***** Eval results *****")
|
||||||
|
for key in sorted(result.keys()):
|
||||||
|
logger.info(" %s = %s", key, str(result[key]))
|
||||||
|
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user