Working evaluation
This commit is contained in:
@@ -16,7 +16,8 @@
|
|||||||
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
|
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor
|
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
|
||||||
|
from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
@@ -230,9 +231,11 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
model.eval()
|
model.eval()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {
|
||||||
'attention_mask': batch[1]
|
'input_ids': batch[0],
|
||||||
}
|
'attention_mask': batch[1]
|
||||||
|
}
|
||||||
|
|
||||||
if args.model_type != 'distilbert':
|
if args.model_type != 'distilbert':
|
||||||
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids
|
||||||
example_indices = batch[3]
|
example_indices = batch[3]
|
||||||
@@ -244,18 +247,8 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
for i, example_index in enumerate(example_indices):
|
for i, example_index in enumerate(example_indices):
|
||||||
eval_feature = features[example_index.item()]
|
eval_feature = features[example_index.item()]
|
||||||
unique_id = int(eval_feature.unique_id)
|
unique_id = int(eval_feature.unique_id)
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
|
||||||
# XLNet uses a more complex post-processing procedure
|
result = SquadResult([to_list(output[i]) for output in outputs] + [unique_id])
|
||||||
result = RawResultExtended(unique_id = unique_id,
|
|
||||||
start_top_log_probs = to_list(outputs[0][i]),
|
|
||||||
start_top_index = to_list(outputs[1][i]),
|
|
||||||
end_top_log_probs = to_list(outputs[2][i]),
|
|
||||||
end_top_index = to_list(outputs[3][i]),
|
|
||||||
cls_logits = to_list(outputs[4][i]))
|
|
||||||
else:
|
|
||||||
result = RawResult(unique_id = unique_id,
|
|
||||||
start_logits = to_list(outputs[0][i]),
|
|
||||||
end_logits = to_list(outputs[1][i]))
|
|
||||||
all_results.append(result)
|
all_results.append(result)
|
||||||
|
|
||||||
evalTime = timeit.default_timer() - start_time
|
evalTime = timeit.default_timer() - start_time
|
||||||
@@ -271,22 +264,18 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
|
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
# XLNet uses a more complex post-processing procedure
|
# XLNet uses a more complex post-processing procedure
|
||||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||||
args.max_answer_length, output_prediction_file,
|
args.max_answer_length, output_prediction_file,
|
||||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||||
model.config.start_n_top, model.config.end_n_top,
|
model.config.start_n_top, model.config.end_n_top,
|
||||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||||
else:
|
else:
|
||||||
write_predictions(examples, features, all_results, args.n_best_size,
|
predictions = compute_predictions(examples, features, all_results, args.n_best_size,
|
||||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||||
|
|
||||||
# Evaluate with the official SQuAD script
|
results = squad_evaluate(examples, predictions)
|
||||||
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
|
|
||||||
pred_file=output_prediction_file,
|
|
||||||
na_prob_file=output_null_log_odds_file)
|
|
||||||
results = evaluate_on_squad(evaluate_options)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
||||||
@@ -306,7 +295,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
logger.info("Creating features from dataset file at %s", input_file)
|
logger.info("Creating features from dataset file at %s", input_file)
|
||||||
|
|
||||||
processor = SquadV2Processor()
|
processor = SquadV2Processor()
|
||||||
examples = processor.get_dev_examples("examples/squad") if evaluate else processor.get_train_examples("examples/squad")
|
examples = processor.get_dev_examples("examples/squad", only_first=100) if evaluate else processor.get_train_examples("examples/squad")
|
||||||
|
# import tensorflow_datasets as tfds
|
||||||
|
# tfds_examples = tfds.load("squad")
|
||||||
|
# examples = SquadV1Processor().get_examples_from_dataset(tfds_examples["validation"])
|
||||||
|
|
||||||
features = squad_convert_examples_to_features(
|
features = squad_convert_examples_to_features(
|
||||||
examples=examples,
|
examples=examples,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
|||||||
@@ -1,15 +1,323 @@
|
|||||||
|
""" Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was
|
||||||
|
modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0
|
||||||
|
|
||||||
|
In addition to basic functionality, we also compute additional statistics and
|
||||||
|
plot precision-recall curves if an additional na_prob.json file is provided.
|
||||||
|
This file is expected to map question ID's to the model's predicted probability
|
||||||
|
that a question is unanswerable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import collections
|
import collections
|
||||||
from io import open
|
from io import open
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import string
|
||||||
|
import re
|
||||||
|
|
||||||
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_answer(s):
|
||||||
|
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||||
|
def remove_articles(text):
|
||||||
|
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
|
||||||
|
return re.sub(regex, ' ', text)
|
||||||
|
|
||||||
|
def white_space_fix(text):
|
||||||
|
return ' '.join(text.split())
|
||||||
|
|
||||||
|
def remove_punc(text):
|
||||||
|
exclude = set(string.punctuation)
|
||||||
|
return ''.join(ch for ch in text if ch not in exclude)
|
||||||
|
|
||||||
|
def lower(text):
|
||||||
|
return text.lower()
|
||||||
|
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokens(s):
|
||||||
|
if not s:
|
||||||
|
return []
|
||||||
|
return normalize_answer(s).split()
|
||||||
|
|
||||||
|
|
||||||
|
def compute_exact(a_gold, a_pred):
|
||||||
|
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
|
||||||
|
|
||||||
|
|
||||||
|
def compute_f1(a_gold, a_pred):
|
||||||
|
gold_toks = get_tokens(a_gold)
|
||||||
|
pred_toks = get_tokens(a_pred)
|
||||||
|
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
|
||||||
|
num_same = sum(common.values())
|
||||||
|
if len(gold_toks) == 0 or len(pred_toks) == 0:
|
||||||
|
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
|
||||||
|
return int(gold_toks == pred_toks)
|
||||||
|
if num_same == 0:
|
||||||
|
return 0
|
||||||
|
precision = 1.0 * num_same / len(pred_toks)
|
||||||
|
recall = 1.0 * num_same / len(gold_toks)
|
||||||
|
f1 = (2 * precision * recall) / (precision + recall)
|
||||||
|
return f1
|
||||||
|
|
||||||
|
|
||||||
|
def get_raw_scores(examples, preds):
|
||||||
|
"""
|
||||||
|
Computes the exact and f1 scores from the examples and the model predictions
|
||||||
|
"""
|
||||||
|
exact_scores = {}
|
||||||
|
f1_scores = {}
|
||||||
|
|
||||||
|
for example in examples:
|
||||||
|
qas_id = example.qas_id
|
||||||
|
gold_answers = [answer['text'] for answer in example.answers if normalize_answer(answer['text'])]
|
||||||
|
|
||||||
|
if not gold_answers:
|
||||||
|
# For unanswerable questions, only correct answer is empty string
|
||||||
|
gold_answers = ['']
|
||||||
|
|
||||||
|
if qas_id not in preds:
|
||||||
|
print('Missing prediction for %s' % qas_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
prediction = preds[qas_id]
|
||||||
|
exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
|
||||||
|
f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
|
||||||
|
|
||||||
|
return exact_scores, f1_scores
|
||||||
|
|
||||||
|
|
||||||
|
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
|
||||||
|
new_scores = {}
|
||||||
|
for qid, s in scores.items():
|
||||||
|
pred_na = na_probs[qid] > na_prob_thresh
|
||||||
|
if pred_na:
|
||||||
|
new_scores[qid] = float(not qid_to_has_ans[qid])
|
||||||
|
else:
|
||||||
|
new_scores[qid] = s
|
||||||
|
return new_scores
|
||||||
|
|
||||||
|
|
||||||
|
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
||||||
|
if not qid_list:
|
||||||
|
total = len(exact_scores)
|
||||||
|
return collections.OrderedDict([
|
||||||
|
('exact', 100.0 * sum(exact_scores.values()) / total),
|
||||||
|
('f1', 100.0 * sum(f1_scores.values()) / total),
|
||||||
|
('total', total),
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
total = len(qid_list)
|
||||||
|
return collections.OrderedDict([
|
||||||
|
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
||||||
|
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
||||||
|
('total', total),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def merge_eval(main_eval, new_eval, prefix):
|
||||||
|
for k in new_eval:
|
||||||
|
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
|
||||||
|
|
||||||
|
|
||||||
|
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
||||||
|
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
||||||
|
cur_score = num_no_ans
|
||||||
|
best_score = cur_score
|
||||||
|
best_thresh = 0.0
|
||||||
|
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||||
|
for _, qid in enumerate(qid_list):
|
||||||
|
if qid not in scores:
|
||||||
|
continue
|
||||||
|
if qid_to_has_ans[qid]:
|
||||||
|
diff = scores[qid]
|
||||||
|
else:
|
||||||
|
if preds[qid]:
|
||||||
|
diff = -1
|
||||||
|
else:
|
||||||
|
diff = 0
|
||||||
|
cur_score += diff
|
||||||
|
if cur_score > best_score:
|
||||||
|
best_score = cur_score
|
||||||
|
best_thresh = na_probs[qid]
|
||||||
|
return 100.0 * best_score / len(scores), best_thresh
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||||
|
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||||
|
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||||
|
|
||||||
|
main_eval['best_exact'] = best_exact
|
||||||
|
main_eval['best_exact_thresh'] = exact_thresh
|
||||||
|
main_eval['best_f1'] = best_f1
|
||||||
|
main_eval['best_f1_thresh'] = f1_thresh
|
||||||
|
|
||||||
|
|
||||||
|
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
|
||||||
|
qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
|
||||||
|
has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
|
||||||
|
no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
|
||||||
|
|
||||||
|
if no_answer_probs is None:
|
||||||
|
no_answer_probs = {k: 0.0 for k in preds}
|
||||||
|
|
||||||
|
exact, f1 = get_raw_scores(examples, preds)
|
||||||
|
|
||||||
|
exact_threshold = apply_no_ans_threshold(exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
||||||
|
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
||||||
|
|
||||||
|
evaluation = make_eval_dict(exact_threshold, f1_threshold)
|
||||||
|
|
||||||
|
if has_answer_qids:
|
||||||
|
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
|
||||||
|
merge_eval(evaluation, has_ans_eval, 'HasAns')
|
||||||
|
|
||||||
|
if no_answer_qids:
|
||||||
|
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
|
||||||
|
merge_eval(evaluation, no_ans_eval, 'NoAns')
|
||||||
|
|
||||||
|
if no_answer_probs:
|
||||||
|
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
|
||||||
|
|
||||||
|
return evaluation
|
||||||
|
|
||||||
|
|
||||||
|
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||||
|
"""Project the tokenized prediction back to the original text."""
|
||||||
|
|
||||||
|
# When we created the data, we kept track of the alignment between original
|
||||||
|
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
|
||||||
|
# now `orig_text` contains the span of our original text corresponding to the
|
||||||
|
# span that we predicted.
|
||||||
|
#
|
||||||
|
# However, `orig_text` may contain extra characters that we don't want in
|
||||||
|
# our prediction.
|
||||||
|
#
|
||||||
|
# For example, let's say:
|
||||||
|
# pred_text = steve smith
|
||||||
|
# orig_text = Steve Smith's
|
||||||
|
#
|
||||||
|
# We don't want to return `orig_text` because it contains the extra "'s".
|
||||||
|
#
|
||||||
|
# We don't want to return `pred_text` because it's already been normalized
|
||||||
|
# (the SQuAD eval script also does punctuation stripping/lower casing but
|
||||||
|
# our tokenizer does additional normalization like stripping accent
|
||||||
|
# characters).
|
||||||
|
#
|
||||||
|
# What we really want to return is "Steve Smith".
|
||||||
|
#
|
||||||
|
# Therefore, we have to apply a semi-complicated alignment heuristic between
|
||||||
|
# `pred_text` and `orig_text` to get a character-to-character alignment. This
|
||||||
|
# can fail in certain cases in which case we just return `orig_text`.
|
||||||
|
|
||||||
|
def _strip_spaces(text):
|
||||||
|
ns_chars = []
|
||||||
|
ns_to_s_map = collections.OrderedDict()
|
||||||
|
for (i, c) in enumerate(text):
|
||||||
|
if c == " ":
|
||||||
|
continue
|
||||||
|
ns_to_s_map[len(ns_chars)] = i
|
||||||
|
ns_chars.append(c)
|
||||||
|
ns_text = "".join(ns_chars)
|
||||||
|
return (ns_text, ns_to_s_map)
|
||||||
|
|
||||||
|
# We first tokenize `orig_text`, strip whitespace from the result
|
||||||
|
# and `pred_text`, and check if they are the same length. If they are
|
||||||
|
# NOT the same length, the heuristic has failed. If they are the same
|
||||||
|
# length, we assume the characters are one-to-one aligned.
|
||||||
|
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
||||||
|
|
||||||
|
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
||||||
|
|
||||||
|
start_position = tok_text.find(pred_text)
|
||||||
|
if start_position == -1:
|
||||||
|
if verbose_logging:
|
||||||
|
logger.info(
|
||||||
|
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
||||||
|
return orig_text
|
||||||
|
end_position = start_position + len(pred_text) - 1
|
||||||
|
|
||||||
|
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
||||||
|
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
||||||
|
|
||||||
|
if len(orig_ns_text) != len(tok_ns_text):
|
||||||
|
if verbose_logging:
|
||||||
|
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
||||||
|
orig_ns_text, tok_ns_text)
|
||||||
|
return orig_text
|
||||||
|
|
||||||
|
# We then project the characters in `pred_text` back to `orig_text` using
|
||||||
|
# the character-to-character alignment.
|
||||||
|
tok_s_to_ns_map = {}
|
||||||
|
for (i, tok_index) in tok_ns_to_s_map.items():
|
||||||
|
tok_s_to_ns_map[tok_index] = i
|
||||||
|
|
||||||
|
orig_start_position = None
|
||||||
|
if start_position in tok_s_to_ns_map:
|
||||||
|
ns_start_position = tok_s_to_ns_map[start_position]
|
||||||
|
if ns_start_position in orig_ns_to_s_map:
|
||||||
|
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
||||||
|
|
||||||
|
if orig_start_position is None:
|
||||||
|
if verbose_logging:
|
||||||
|
logger.info("Couldn't map start position")
|
||||||
|
return orig_text
|
||||||
|
|
||||||
|
orig_end_position = None
|
||||||
|
if end_position in tok_s_to_ns_map:
|
||||||
|
ns_end_position = tok_s_to_ns_map[end_position]
|
||||||
|
if ns_end_position in orig_ns_to_s_map:
|
||||||
|
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
||||||
|
|
||||||
|
if orig_end_position is None:
|
||||||
|
if verbose_logging:
|
||||||
|
logger.info("Couldn't map end position")
|
||||||
|
return orig_text
|
||||||
|
|
||||||
|
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
||||||
|
return output_text
|
||||||
|
|
||||||
|
|
||||||
|
def _get_best_indexes(logits, n_best_size):
|
||||||
|
"""Get the n-best logits from a list."""
|
||||||
|
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
best_indexes = []
|
||||||
|
for i in range(len(index_and_score)):
|
||||||
|
if i >= n_best_size:
|
||||||
|
break
|
||||||
|
best_indexes.append(index_and_score[i][0])
|
||||||
|
return best_indexes
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_softmax(scores):
|
||||||
|
"""Compute softmax probability over raw logits."""
|
||||||
|
if not scores:
|
||||||
|
return []
|
||||||
|
|
||||||
|
max_score = None
|
||||||
|
for score in scores:
|
||||||
|
if max_score is None or score > max_score:
|
||||||
|
max_score = score
|
||||||
|
|
||||||
|
exp_scores = []
|
||||||
|
total_sum = 0.0
|
||||||
|
for score in scores:
|
||||||
|
x = math.exp(score - max_score)
|
||||||
|
exp_scores.append(x)
|
||||||
|
total_sum += x
|
||||||
|
|
||||||
|
probs = []
|
||||||
|
for score in exp_scores:
|
||||||
|
probs.append(score / total_sum)
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
def compute_predictions(all_examples, all_features, all_results, n_best_size,
|
def compute_predictions(all_examples, all_features, all_results, n_best_size,
|
||||||
max_answer_length, do_lower_case, output_prediction_file,
|
max_answer_length, do_lower_case, output_prediction_file,
|
||||||
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
||||||
@@ -204,132 +512,192 @@ def compute_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
return all_predictions
|
return all_predictions
|
||||||
|
|
||||||
|
|
||||||
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
def compute_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
||||||
"""Project the tokenized prediction back to the original text."""
|
max_answer_length, output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file, orig_data_file,
|
||||||
|
start_n_top, end_n_top, version_2_with_negative,
|
||||||
|
tokenizer, verbose_logging):
|
||||||
|
""" XLNet write prediction logic (more complex than Bert's).
|
||||||
|
Write final predictions to the json file and log-odds of null if needed.
|
||||||
|
|
||||||
# When we created the data, we kept track of the alignment between original
|
Requires utils_squad_evaluate.py
|
||||||
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
|
"""
|
||||||
# now `orig_text` contains the span of our original text corresponding to the
|
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
# span that we predicted.
|
"PrelimPrediction",
|
||||||
#
|
["feature_index", "start_index", "end_index",
|
||||||
# However, `orig_text` may contain extra characters that we don't want in
|
"start_log_prob", "end_log_prob"])
|
||||||
# our prediction.
|
|
||||||
#
|
|
||||||
# For example, let's say:
|
|
||||||
# pred_text = steve smith
|
|
||||||
# orig_text = Steve Smith's
|
|
||||||
#
|
|
||||||
# We don't want to return `orig_text` because it contains the extra "'s".
|
|
||||||
#
|
|
||||||
# We don't want to return `pred_text` because it's already been normalized
|
|
||||||
# (the SQuAD eval script also does punctuation stripping/lower casing but
|
|
||||||
# our tokenizer does additional normalization like stripping accent
|
|
||||||
# characters).
|
|
||||||
#
|
|
||||||
# What we really want to return is "Steve Smith".
|
|
||||||
#
|
|
||||||
# Therefore, we have to apply a semi-complicated alignment heuristic between
|
|
||||||
# `pred_text` and `orig_text` to get a character-to-character alignment. This
|
|
||||||
# can fail in certain cases in which case we just return `orig_text`.
|
|
||||||
|
|
||||||
def _strip_spaces(text):
|
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
ns_chars = []
|
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
||||||
ns_to_s_map = collections.OrderedDict()
|
|
||||||
for (i, c) in enumerate(text):
|
logger.info("Writing predictions to: %s", output_prediction_file)
|
||||||
if c == " ":
|
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||||
|
|
||||||
|
example_index_to_features = collections.defaultdict(list)
|
||||||
|
for feature in all_features:
|
||||||
|
example_index_to_features[feature.example_index].append(feature)
|
||||||
|
|
||||||
|
unique_id_to_result = {}
|
||||||
|
for result in all_results:
|
||||||
|
unique_id_to_result[result.unique_id] = result
|
||||||
|
|
||||||
|
all_predictions = collections.OrderedDict()
|
||||||
|
all_nbest_json = collections.OrderedDict()
|
||||||
|
scores_diff_json = collections.OrderedDict()
|
||||||
|
|
||||||
|
for (example_index, example) in enumerate(all_examples):
|
||||||
|
features = example_index_to_features[example_index]
|
||||||
|
|
||||||
|
prelim_predictions = []
|
||||||
|
# keep track of the minimum score of null start+end of position 0
|
||||||
|
score_null = 1000000 # large and positive
|
||||||
|
|
||||||
|
for (feature_index, feature) in enumerate(features):
|
||||||
|
result = unique_id_to_result[feature.unique_id]
|
||||||
|
|
||||||
|
cur_null_score = result.cls_logits
|
||||||
|
|
||||||
|
# if we could have irrelevant answers, get the min score of irrelevant
|
||||||
|
score_null = min(score_null, cur_null_score)
|
||||||
|
|
||||||
|
for i in range(start_n_top):
|
||||||
|
for j in range(end_n_top):
|
||||||
|
start_log_prob = result.start_top_log_probs[i]
|
||||||
|
start_index = result.start_top_index[i]
|
||||||
|
|
||||||
|
j_index = i * end_n_top + j
|
||||||
|
|
||||||
|
end_log_prob = result.end_top_log_probs[j_index]
|
||||||
|
end_index = result.end_top_index[j_index]
|
||||||
|
|
||||||
|
# We could hypothetically create invalid predictions, e.g., predict
|
||||||
|
# that the start of the span is in the question. We throw out all
|
||||||
|
# invalid predictions.
|
||||||
|
if start_index >= feature.paragraph_len - 1:
|
||||||
|
continue
|
||||||
|
if end_index >= feature.paragraph_len - 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not feature.token_is_max_context.get(start_index, False):
|
||||||
|
continue
|
||||||
|
if end_index < start_index:
|
||||||
|
continue
|
||||||
|
length = end_index - start_index + 1
|
||||||
|
if length > max_answer_length:
|
||||||
|
continue
|
||||||
|
|
||||||
|
prelim_predictions.append(
|
||||||
|
_PrelimPrediction(
|
||||||
|
feature_index=feature_index,
|
||||||
|
start_index=start_index,
|
||||||
|
end_index=end_index,
|
||||||
|
start_log_prob=start_log_prob,
|
||||||
|
end_log_prob=end_log_prob))
|
||||||
|
|
||||||
|
prelim_predictions = sorted(
|
||||||
|
prelim_predictions,
|
||||||
|
key=lambda x: (x.start_log_prob + x.end_log_prob),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
seen_predictions = {}
|
||||||
|
nbest = []
|
||||||
|
for pred in prelim_predictions:
|
||||||
|
if len(nbest) >= n_best_size:
|
||||||
|
break
|
||||||
|
feature = features[pred.feature_index]
|
||||||
|
|
||||||
|
# XLNet un-tokenizer
|
||||||
|
# Let's keep it simple for now and see if we need all this later.
|
||||||
|
#
|
||||||
|
# tok_start_to_orig_index = feature.tok_start_to_orig_index
|
||||||
|
# tok_end_to_orig_index = feature.tok_end_to_orig_index
|
||||||
|
# start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
||||||
|
# end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
||||||
|
# paragraph_text = example.paragraph_text
|
||||||
|
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
||||||
|
|
||||||
|
# Previously used Bert untokenizer
|
||||||
|
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
|
||||||
|
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
||||||
|
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
||||||
|
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
|
||||||
|
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
||||||
|
|
||||||
|
# Clean whitespace
|
||||||
|
tok_text = tok_text.strip()
|
||||||
|
tok_text = " ".join(tok_text.split())
|
||||||
|
orig_text = " ".join(orig_tokens)
|
||||||
|
|
||||||
|
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
|
||||||
|
verbose_logging)
|
||||||
|
|
||||||
|
if final_text in seen_predictions:
|
||||||
continue
|
continue
|
||||||
ns_to_s_map[len(ns_chars)] = i
|
|
||||||
ns_chars.append(c)
|
|
||||||
ns_text = "".join(ns_chars)
|
|
||||||
return (ns_text, ns_to_s_map)
|
|
||||||
|
|
||||||
# We first tokenize `orig_text`, strip whitespace from the result
|
seen_predictions[final_text] = True
|
||||||
# and `pred_text`, and check if they are the same length. If they are
|
|
||||||
# NOT the same length, the heuristic has failed. If they are the same
|
|
||||||
# length, we assume the characters are one-to-one aligned.
|
|
||||||
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
|
||||||
|
|
||||||
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
nbest.append(
|
||||||
|
_NbestPrediction(
|
||||||
|
text=final_text,
|
||||||
|
start_log_prob=pred.start_log_prob,
|
||||||
|
end_log_prob=pred.end_log_prob))
|
||||||
|
|
||||||
start_position = tok_text.find(pred_text)
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
if start_position == -1:
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
if verbose_logging:
|
if not nbest:
|
||||||
logger.info(
|
nbest.append(
|
||||||
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
|
_NbestPrediction(text="", start_log_prob=-1e6,
|
||||||
return orig_text
|
end_log_prob=-1e6))
|
||||||
end_position = start_position + len(pred_text) - 1
|
|
||||||
|
|
||||||
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
total_scores = []
|
||||||
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
best_non_null_entry = None
|
||||||
|
for entry in nbest:
|
||||||
|
total_scores.append(entry.start_log_prob + entry.end_log_prob)
|
||||||
|
if not best_non_null_entry:
|
||||||
|
best_non_null_entry = entry
|
||||||
|
|
||||||
if len(orig_ns_text) != len(tok_ns_text):
|
probs = _compute_softmax(total_scores)
|
||||||
if verbose_logging:
|
|
||||||
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
|
|
||||||
orig_ns_text, tok_ns_text)
|
|
||||||
return orig_text
|
|
||||||
|
|
||||||
# We then project the characters in `pred_text` back to `orig_text` using
|
nbest_json = []
|
||||||
# the character-to-character alignment.
|
for (i, entry) in enumerate(nbest):
|
||||||
tok_s_to_ns_map = {}
|
output = collections.OrderedDict()
|
||||||
for (i, tok_index) in tok_ns_to_s_map.items():
|
output["text"] = entry.text
|
||||||
tok_s_to_ns_map[tok_index] = i
|
output["probability"] = probs[i]
|
||||||
|
output["start_log_prob"] = entry.start_log_prob
|
||||||
|
output["end_log_prob"] = entry.end_log_prob
|
||||||
|
nbest_json.append(output)
|
||||||
|
|
||||||
orig_start_position = None
|
assert len(nbest_json) >= 1
|
||||||
if start_position in tok_s_to_ns_map:
|
assert best_non_null_entry is not None
|
||||||
ns_start_position = tok_s_to_ns_map[start_position]
|
|
||||||
if ns_start_position in orig_ns_to_s_map:
|
|
||||||
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
|
||||||
|
|
||||||
if orig_start_position is None:
|
score_diff = score_null
|
||||||
if verbose_logging:
|
scores_diff_json[example.qas_id] = score_diff
|
||||||
logger.info("Couldn't map start position")
|
# note(zhiliny): always predict best_non_null_entry
|
||||||
return orig_text
|
# and the evaluation script will search for the best threshold
|
||||||
|
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||||
|
|
||||||
orig_end_position = None
|
all_nbest_json[example.qas_id] = nbest_json
|
||||||
if end_position in tok_s_to_ns_map:
|
|
||||||
ns_end_position = tok_s_to_ns_map[end_position]
|
|
||||||
if ns_end_position in orig_ns_to_s_map:
|
|
||||||
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
|
||||||
|
|
||||||
if orig_end_position is None:
|
with open(output_prediction_file, "w") as writer:
|
||||||
if verbose_logging:
|
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||||
logger.info("Couldn't map end position")
|
|
||||||
return orig_text
|
|
||||||
|
|
||||||
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
|
with open(output_nbest_file, "w") as writer:
|
||||||
return output_text
|
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
if version_2_with_negative:
|
||||||
|
with open(output_null_log_odds_file, "w") as writer:
|
||||||
|
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||||
|
|
||||||
def _get_best_indexes(logits, n_best_size):
|
with open(orig_data_file, "r", encoding='utf-8') as reader:
|
||||||
"""Get the n-best logits from a list."""
|
orig_data = json.load(reader)["data"]
|
||||||
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
|
||||||
|
|
||||||
best_indexes = []
|
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
||||||
for i in range(len(index_and_score)):
|
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||||
if i >= n_best_size:
|
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||||
break
|
exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
|
||||||
best_indexes.append(index_and_score[i][0])
|
out_eval = {}
|
||||||
return best_indexes
|
|
||||||
|
|
||||||
|
find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw, scores_diff_json, qid_to_has_ans)
|
||||||
|
|
||||||
def _compute_softmax(scores):
|
return out_eval
|
||||||
"""Compute softmax probability over raw logits."""
|
|
||||||
if not scores:
|
|
||||||
return []
|
|
||||||
|
|
||||||
max_score = None
|
|
||||||
for score in scores:
|
|
||||||
if max_score is None or score > max_score:
|
|
||||||
max_score = score
|
|
||||||
|
|
||||||
exp_scores = []
|
|
||||||
total_sum = 0.0
|
|
||||||
for score in scores:
|
|
||||||
x = math.exp(score - max_score)
|
|
||||||
exp_scores.append(x)
|
|
||||||
total_sum += x
|
|
||||||
|
|
||||||
probs = []
|
|
||||||
for score in exp_scores:
|
|
||||||
probs.append(score / total_sum)
|
|
||||||
return probs
|
|
||||||
|
|||||||
@@ -306,13 +306,13 @@ class SquadProcessor(DataProcessor):
|
|||||||
else:
|
else:
|
||||||
is_impossible = False
|
is_impossible = False
|
||||||
|
|
||||||
if not is_impossible and is_training:
|
if not is_impossible:
|
||||||
if (len(qa["answers"]) != 1):
|
if is_training:
|
||||||
raise ValueError(
|
answer = qa["answers"][0]
|
||||||
"For training, each question should have exactly 1 answer.")
|
answer_text = answer['text']
|
||||||
answer = qa["answers"][0]
|
start_position_character = answer['answer_start']
|
||||||
answer_text = answer['text']
|
else:
|
||||||
start_position_character = answer['answer_start']
|
answers = qa["answers"]
|
||||||
|
|
||||||
example = SquadExample(
|
example = SquadExample(
|
||||||
qas_id=qas_id,
|
qas_id=qas_id,
|
||||||
@@ -321,7 +321,8 @@ class SquadProcessor(DataProcessor):
|
|||||||
answer_text=answer_text,
|
answer_text=answer_text,
|
||||||
start_position_character=start_position_character,
|
start_position_character=start_position_character,
|
||||||
title=title,
|
title=title,
|
||||||
is_impossible=is_impossible
|
is_impossible=is_impossible,
|
||||||
|
answers=answers
|
||||||
)
|
)
|
||||||
|
|
||||||
examples.append(example)
|
examples.append(example)
|
||||||
@@ -352,6 +353,7 @@ class SquadExample(object):
|
|||||||
answer_text,
|
answer_text,
|
||||||
start_position_character,
|
start_position_character,
|
||||||
title,
|
title,
|
||||||
|
answers=None,
|
||||||
is_impossible=False):
|
is_impossible=False):
|
||||||
self.qas_id = qas_id
|
self.qas_id = qas_id
|
||||||
self.question_text = question_text
|
self.question_text = question_text
|
||||||
@@ -359,6 +361,7 @@ class SquadExample(object):
|
|||||||
self.answer_text = answer_text
|
self.answer_text = answer_text
|
||||||
self.title = title
|
self.title = title
|
||||||
self.is_impossible = is_impossible
|
self.is_impossible = is_impossible
|
||||||
|
self.answers = answers
|
||||||
|
|
||||||
self.start_position, self.end_position = 0, 0
|
self.start_position, self.end_position = 0, 0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user