Improve global visibility on the run_squad script, remove unused files and fixes related to XLNet.
This commit is contained in:
LysandreJik
2019-12-05 16:01:51 -05:00
parent 9ecd83dace
commit e9217da5ff
5 changed files with 45 additions and 1387 deletions

View File

@@ -578,7 +578,6 @@ def compute_predictions_log_probs(
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
orig_data_file,
start_n_top,
end_n_top,
version_2_with_negative,
@@ -756,15 +755,4 @@ def compute_predictions_log_probs(
with open(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
with open(orig_data_file, "r", encoding='utf-8') as reader:
orig_data = json.load(reader)["data"]
qid_to_has_ans = make_qid_to_has_ans(orig_data)
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
out_eval = {}
find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw, scores_diff_json, qid_to_has_ans)
return out_eval
return all_predictions

View File

@@ -9,7 +9,7 @@ from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
from .utils import DataProcessor, InputExample, InputFeatures
from ...file_utils import is_tf_available, is_torch_available
if is_torch_available:
if is_torch_available():
import torch
from torch.utils.data import TensorDataset