Naming update + XLNet/XLM evaluation

This commit is contained in:
LysandreJik
2019-12-04 10:37:00 -05:00
parent de276de1c1
commit 9ddc3f1a12
2 changed files with 85 additions and 18 deletions

View File

@@ -17,7 +17,7 @@
from __future__ import absolute_import, division, print_function
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
from transformers.data.metrics.squad_metrics import compute_predictions, compute_predictions_extended, squad_evaluate
from transformers.data.metrics.squad_metrics import compute_predictions_logits, compute_predictions_log_probs, squad_evaluate
import argparse
import logging
@@ -264,13 +264,13 @@ def evaluate(args, model, tokenizer, prefix=""):
if args.model_type in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure
predictions = compute_predictions_extended(examples, features, all_results, args.n_best_size,
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
args.max_answer_length, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.predict_file,
model.config.start_n_top, model.config.end_n_top,
args.version_2_with_negative, tokenizer, args.verbose_logging)
else:
predictions = compute_predictions(examples, features, all_results, args.n_best_size,
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold)