Naming update + XLNet/XLM evaluation
This commit is contained in:
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
|
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
|
||||||
from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate
|
from transformers.data.metrics.squad_metrics import compute_predictions_logits, compute_predictions_log_probs, squad_evaluate
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
@@ -264,13 +264,13 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
|
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
# XLNet uses a more complex post-processing procedure
|
# XLNet uses a more complex post-processing procedure
|
||||||
predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size,
|
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
|
||||||
args.max_answer_length, output_prediction_file,
|
args.max_answer_length, output_prediction_file,
|
||||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||||
model.config.start_n_top, model.config.end_n_top,
|
model.config.start_n_top, model.config.end_n_top,
|
||||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||||
else:
|
else:
|
||||||
predictions = compute_predictions(examples, features, all_results, args.n_best_size,
|
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
||||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||||
|
|||||||
@@ -125,6 +125,53 @@ def merge_eval(main_eval, new_eval, prefix):
|
|||||||
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
|
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
|
||||||
|
|
||||||
|
|
||||||
|
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
    """Sweep the values of ``na_probs`` as thresholds and keep the best one.

    The running score starts at the number of ids marked unanswerable in
    ``qid_to_has_ans``. Ids are then visited in ascending ``na_probs`` order
    and the running score is adjusted per id:

    * answerable id: add ``scores[qid]``;
    * unanswerable id with a truthy ``preds[qid]``: subtract 1;
    * unanswerable id with a falsy ``preds[qid]``: no change.

    Whenever the running score strictly exceeds the best seen so far, the
    current id's ``na_probs`` value becomes the best threshold.

    Args:
        preds: Mapping from qid to the prediction; only its truthiness is
            inspected here.
        scores: Mapping from qid to a numeric score. Ids missing from this
            mapping are skipped in the sweep and contribute nothing to the
            answerable total.
        na_probs: Mapping from qid to a sortable value (presumably the
            no-answer probability -- confirm against the caller).
        qid_to_has_ans: Mapping from qid to a truthy/falsy "has an answer"
            flag. Must contain every id of ``na_probs`` (a missing id raises
            ``KeyError``, as before).

    Returns:
        A 3-tuple ``(best_score_pct, best_thresh, has_ans_avg)`` where
        ``best_score_pct`` is ``100 * best_score / len(scores)``,
        ``best_thresh`` is the selected threshold (``0.0`` if no id improved
        on the starting score) and ``has_ans_avg`` is the summed score of
        answerable ids divided by the number of answerable ids in
        ``na_probs``. A ratio whose denominator is zero is reported as
        ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])

    for qid in qid_list:
        if qid not in scores:
            continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        elif preds[qid]:
            # A non-empty prediction on an unanswerable id costs one point.
            diff = -1
        else:
            diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]

    has_ans_score, has_ans_cnt = 0, 0
    for qid in qid_list:
        if not qid_to_has_ans[qid]:
            continue
        # Counted even when the id has no score, so it still weighs on the
        # average below.
        has_ans_cnt += 1

        if qid not in scores:
            continue
        has_ans_score += scores[qid]

    # Guard both ratios: an evaluation set with no scored ids, or with no
    # answerable ids, used to crash here with ZeroDivisionError.
    num_scored = len(scores)
    best_score_pct = 100.0 * best_score / num_scored if num_scored else 0.0
    has_ans_avg = 1.0 * has_ans_score / has_ans_cnt if has_ans_cnt else 0.0
    return best_score_pct, best_thresh, has_ans_avg
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    """Run the v2 threshold search for exact-match and F1 and record the results.

    ``find_best_thresh_v2`` is called once with ``exact_raw`` and once with
    ``f1_raw``; the three values each call returns are written into
    ``main_eval`` in place. Nothing is returned.

    Keys written to ``main_eval``: ``best_exact``, ``best_exact_thresh``,
    ``best_f1``, ``best_f1_thresh``, ``has_ans_exact`` and ``has_ans_f1``.
    """
    exact_best, exact_cutoff, exact_has_ans = find_best_thresh_v2(
        preds, exact_raw, na_probs, qid_to_has_ans)
    f1_best, f1_cutoff, f1_has_ans = find_best_thresh_v2(
        preds, f1_raw, na_probs, qid_to_has_ans)

    # Written one key at a time, in the same order as before, so an
    # insertion-ordered mapping ends up with the same layout.
    summary = (
        ('best_exact', exact_best),
        ('best_exact_thresh', exact_cutoff),
        ('best_f1', f1_best),
        ('best_f1_thresh', f1_cutoff),
        ('has_ans_exact', exact_has_ans),
        ('has_ans_f1', f1_has_ans),
    )
    for key, value in summary:
        main_eval[key] = value
|
||||||
|
|
||||||
|
|
||||||
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
||||||
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
||||||
cur_score = num_no_ans
|
cur_score = num_no_ans
|
||||||
@@ -318,10 +365,20 @@ def _compute_softmax(scores):
|
|||||||
return probs
|
return probs
|
||||||
|
|
||||||
|
|
||||||
def compute_predictions(all_examples, all_features, all_results, n_best_size,
|
def compute_predictions_logits(
|
||||||
max_answer_length, do_lower_case, output_prediction_file,
|
all_examples,
|
||||||
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
all_features,
|
||||||
version_2_with_negative, null_score_diff_threshold):
|
all_results,
|
||||||
|
n_best_size,
|
||||||
|
max_answer_length,
|
||||||
|
do_lower_case,
|
||||||
|
output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file,
|
||||||
|
verbose_logging,
|
||||||
|
version_2_with_negative,
|
||||||
|
null_score_diff_threshold
|
||||||
|
):
|
||||||
"""Write final predictions to the json file and log-odds of null if needed."""
|
"""Write final predictions to the json file and log-odds of null if needed."""
|
||||||
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
logger.info("Writing predictions to: %s" % (output_prediction_file))
|
||||||
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||||
@@ -450,12 +507,12 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
text="",
|
text="",
|
||||||
start_logit=null_start_logit,
|
start_logit=null_start_logit,
|
||||||
end_logit=null_end_logit))
|
end_logit=null_end_logit))
|
||||||
|
|
||||||
# In very rare edge cases we could only have single null prediction.
|
# In very rare edge cases we could only have single null prediction.
|
||||||
# So we just create a nonce prediction in this case to avoid failure.
|
# So we just create a nonce prediction in this case to avoid failure.
|
||||||
if len(nbest)==1:
|
if len(nbest) == 1:
|
||||||
nbest.insert(0,
|
nbest.insert(0,
|
||||||
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
||||||
|
|
||||||
# In very rare edge cases we could have no valid predictions. So we
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
# just create a nonce prediction in this case to avoid failure.
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
@@ -512,12 +569,22 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
return all_predictions
|
return all_predictions
|
||||||
|
|
||||||
|
|
||||||
def compute_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
def compute_predictions_log_probs(
|
||||||
max_answer_length, output_prediction_file,
|
all_examples,
|
||||||
output_nbest_file,
|
all_features,
|
||||||
output_null_log_odds_file, orig_data_file,
|
all_results,
|
||||||
start_n_top, end_n_top, version_2_with_negative,
|
n_best_size,
|
||||||
tokenizer, verbose_logging):
|
max_answer_length,
|
||||||
|
output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file,
|
||||||
|
orig_data_file,
|
||||||
|
start_n_top,
|
||||||
|
end_n_top,
|
||||||
|
version_2_with_negative,
|
||||||
|
tokenizer,
|
||||||
|
verbose_logging
|
||||||
|
):
|
||||||
""" XLNet write prediction logic (more complex than Bert's).
|
""" XLNet write prediction logic (more complex than Bert's).
|
||||||
Write final predictions to the json file and log-odds of null if needed.
|
Write final predictions to the json file and log-odds of null if needed.
|
||||||
|
|
||||||
@@ -526,7 +593,7 @@ def compute_predictions_extended(all_examples, all_features, all_results, n_best
|
|||||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"PrelimPrediction",
|
"PrelimPrediction",
|
||||||
["feature_index", "start_index", "end_index",
|
["feature_index", "start_index", "end_index",
|
||||||
"start_log_prob", "end_log_prob"])
|
"start_log_prob", "end_log_prob"])
|
||||||
|
|
||||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
||||||
@@ -609,7 +676,7 @@ def compute_predictions_extended(all_examples, all_features, all_results, n_best
|
|||||||
|
|
||||||
# XLNet un-tokenizer
|
# XLNet un-tokenizer
|
||||||
# Let's keep it simple for now and see if we need all this later.
|
# Let's keep it simple for now and see if we need all this later.
|
||||||
#
|
#
|
||||||
# tok_start_to_orig_index = feature.tok_start_to_orig_index
|
# tok_start_to_orig_index = feature.tok_start_to_orig_index
|
||||||
# tok_end_to_orig_index = feature.tok_end_to_orig_index
|
# tok_end_to_orig_index = feature.tok_end_to_orig_index
|
||||||
# start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
# start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
||||||
|
|||||||
Reference in New Issue
Block a user