Naming update + XLNet/XLM evaluation

This commit is contained in:
LysandreJik
2019-12-04 10:37:00 -05:00
parent de276de1c1
commit 9ddc3f1a12
2 changed files with 85 additions and 18 deletions

View File

@@ -17,7 +17,7 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate from transformers.data.metrics.squad_metrics import compute_predictions_logits, compute_predictions_log_probs, squad_evaluate
import argparse import argparse
import logging import logging
@@ -264,13 +264,13 @@ def evaluate(args, model, tokenizer, prefix=""):
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure # XLNet uses a more complex post-processing procedure
predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size, predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
args.max_answer_length, output_prediction_file, args.max_answer_length, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.predict_file, output_nbest_file, output_null_log_odds_file, args.predict_file,
model.config.start_n_top, model.config.end_n_top, model.config.start_n_top, model.config.end_n_top,
args.version_2_with_negative, tokenizer, args.verbose_logging) args.version_2_with_negative, tokenizer, args.verbose_logging)
else: else:
predictions = compute_predictions(examples, features, all_results, args.n_best_size, predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file, args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging, output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold) args.version_2_with_negative, args.null_score_diff_threshold)

View File

@@ -125,6 +125,53 @@ def merge_eval(main_eval, new_eval, prefix):
main_eval['%s_%s' % (prefix, k)] = new_eval[k] main_eval['%s_%s' % (prefix, k)] = new_eval[k]
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for i, qid in enumerate(qid_list):
if qid not in scores:
continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if preds[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
has_ans_score, has_ans_cnt = 0, 0
for qid in qid_list:
if not qid_to_has_ans[qid]:
continue
has_ans_cnt += 1
if qid not in scores:
continue
has_ans_score += scores[qid]
return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(
preds, exact_raw, na_probs, qid_to_has_ans)
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(
preds, f1_raw, na_probs, qid_to_has_ans)
main_eval['best_exact'] = best_exact
main_eval['best_exact_thresh'] = exact_thresh
main_eval['best_f1'] = best_f1
main_eval['best_f1_thresh'] = f1_thresh
main_eval['has_ans_exact'] = has_ans_exact
main_eval['has_ans_f1'] = has_ans_f1
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans cur_score = num_no_ans
@@ -318,10 +365,20 @@ def _compute_softmax(scores):
return probs return probs
def compute_predictions(all_examples, all_features, all_results, n_best_size, def compute_predictions_logits(
max_answer_length, do_lower_case, output_prediction_file, all_examples,
output_nbest_file, output_null_log_odds_file, verbose_logging, all_features,
version_2_with_negative, null_score_diff_threshold): all_results,
n_best_size,
max_answer_length,
do_lower_case,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
verbose_logging,
version_2_with_negative,
null_score_diff_threshold
):
"""Write final predictions to the json file and log-odds of null if needed.""" """Write final predictions to the json file and log-odds of null if needed."""
logger.info("Writing predictions to: %s" % (output_prediction_file)) logger.info("Writing predictions to: %s" % (output_prediction_file))
logger.info("Writing nbest to: %s" % (output_nbest_file)) logger.info("Writing nbest to: %s" % (output_nbest_file))
@@ -512,12 +569,22 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
return all_predictions return all_predictions
def compute_predictions_extended(all_examples, all_features, all_results, n_best_size, def compute_predictions_log_probs(
max_answer_length, output_prediction_file, all_examples,
all_features,
all_results,
n_best_size,
max_answer_length,
output_prediction_file,
output_nbest_file, output_nbest_file,
output_null_log_odds_file, orig_data_file, output_null_log_odds_file,
start_n_top, end_n_top, version_2_with_negative, orig_data_file,
tokenizer, verbose_logging): start_n_top,
end_n_top,
version_2_with_negative,
tokenizer,
verbose_logging
):
""" XLNet write prediction logic (more complex than Bert's). """ XLNet write prediction logic (more complex than Bert's).
Write final predictions to the json file and log-odds of null if needed. Write final predictions to the json file and log-odds of null if needed.