From 9ddc3f1a1227fc9cbe4e5a5c20b21546e438dfb1 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 4 Dec 2019 10:37:00 -0500 Subject: [PATCH] Naming update + XLNet/XLM evaluation --- examples/run_squad.py | 6 +- transformers/data/metrics/squad_metrics.py | 97 ++++++++++++++++++---- 2 files changed, 85 insertions(+), 18 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index b7952487dc..a9ef5c6ba2 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult -from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate +from transformers.data.metrics.squad_metrics import compute_predictions_logits, compute_predictions_log_probs, squad_evaluate import argparse import logging @@ -264,13 +264,13 @@ def evaluate(args, model, tokenizer, prefix=""): if args.model_type in ['xlnet', 'xlm']: # XLNet uses a more complex post-processing procedure - predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size, + predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size, args.max_answer_length, output_prediction_file, output_nbest_file, output_null_log_odds_file, args.predict_file, model.config.start_n_top, model.config.end_n_top, args.version_2_with_negative, tokenizer, args.verbose_logging) else: - predictions = compute_predictions(examples, features, all_results, args.n_best_size, + predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size, args.max_answer_length, args.do_lower_case, output_prediction_file, output_nbest_file, output_null_log_odds_file, args.verbose_logging, args.version_2_with_negative, args.null_score_diff_threshold) diff --git a/transformers/data/metrics/squad_metrics.py b/transformers/data/metrics/squad_metrics.py index 83647a20d0..1f120d354a 100644 --- a/transformers/data/metrics/squad_metrics.py +++ b/transformers/data/metrics/squad_metrics.py @@ -125,6 +125,53 @@ def merge_eval(main_eval, new_eval, prefix): main_eval['%s_%s' % (prefix, k)] = new_eval[k] +def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): + num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) + cur_score = num_no_ans + best_score = cur_score + best_thresh = 0.0 + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + for i, qid in enumerate(qid_list): + if qid not in scores: + continue + if qid_to_has_ans[qid]: + diff = scores[qid] + else: + if preds[qid]: + diff = -1 + else: + diff = 0 + cur_score += diff + if cur_score > best_score: + best_score = cur_score + best_thresh = na_probs[qid] + + has_ans_score, has_ans_cnt = 0, 0 + for qid in qid_list: + if not qid_to_has_ans[qid]: + continue + has_ans_cnt += 1 + + if qid not in scores: + continue + has_ans_score += scores[qid] + + return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt + + +def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): + best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2( + preds, exact_raw, na_probs, qid_to_has_ans) + best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2( + preds, f1_raw, na_probs, qid_to_has_ans) + main_eval['best_exact'] = best_exact + main_eval['best_exact_thresh'] = exact_thresh + main_eval['best_f1'] = best_f1 + main_eval['best_f1_thresh'] = f1_thresh + main_eval['has_ans_exact'] = has_ans_exact + main_eval['has_ans_f1'] = has_ans_f1 + + def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) cur_score = num_no_ans @@ -318,10 +365,20 @@ def _compute_softmax(scores): return probs -def compute_predictions(all_examples, all_features, all_results, n_best_size, - max_answer_length, do_lower_case, output_prediction_file, - output_nbest_file, output_null_log_odds_file, verbose_logging, - version_2_with_negative, null_score_diff_threshold): +def compute_predictions_logits( + all_examples, + all_features, + all_results, + n_best_size, + max_answer_length, + do_lower_case, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + verbose_logging, + version_2_with_negative, + null_score_diff_threshold +): """Write final predictions to the json file and log-odds of null if needed.""" logger.info("Writing predictions to: %s" % (output_prediction_file)) logger.info("Writing nbest to: %s" % (output_nbest_file)) @@ -450,12 +507,12 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size, text="", start_logit=null_start_logit, end_logit=null_end_logit)) - + # In very rare edge cases we could only have single null prediction. # So we just create a nonce prediction in this case to avoid failure. - if len(nbest)==1: + if len(nbest) == 1: nbest.insert(0, - _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) + _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) # In very rare edge cases we could have no valid predictions. So we # just create a nonce prediction in this case to avoid failure. @@ -512,12 +569,22 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size, return all_predictions -def compute_predictions_extended(all_examples, all_features, all_results, n_best_size, - max_answer_length, output_prediction_file, - output_nbest_file, - output_null_log_odds_file, orig_data_file, - start_n_top, end_n_top, version_2_with_negative, - tokenizer, verbose_logging): +def compute_predictions_log_probs( + all_examples, + all_features, + all_results, + n_best_size, + max_answer_length, + output_prediction_file, + output_nbest_file, + output_null_log_odds_file, + orig_data_file, + start_n_top, + end_n_top, + version_2_with_negative, + tokenizer, + verbose_logging +): """ XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of null if needed. @@ -526,7 +593,7 @@ def compute_predictions_extended(all_examples, all_features, all_results, n_best _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name "PrelimPrediction", ["feature_index", "start_index", "end_index", - "start_log_prob", "end_log_prob"]) + "start_log_prob", "end_log_prob"]) _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]) @@ -609,7 +676,7 @@ def compute_predictions_extended(all_examples, all_features, all_results, n_best # XLNet un-tokenizer # Let's keep it simple for now and see if we need all this later. - # + # # tok_start_to_orig_index = feature.tok_start_to_orig_index # tok_end_to_orig_index = feature.tok_end_to_orig_index # start_orig_pos = tok_start_to_orig_index[pred.start_index]