DataParallel for SQuAD + fix XLM

This commit is contained in:
Lysandre
2019-12-10 19:21:20 +00:00
parent e6cff60b4c
commit dc4e9e5cb3
3 changed files with 15 additions and 2 deletions

View File

@@ -299,10 +299,14 @@ def evaluate(args, model, tokenizer, prefix=""):
# XLNet and XLM use a more complex post-processing procedure
if args.model_type in ['xlnet', 'xlm']:
start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
args.max_answer_length, output_prediction_file,
output_nbest_file, output_null_log_odds_file,
model.config.start_n_top, model.config.end_n_top,
start_n_top, end_n_top,
args.version_2_with_negative, tokenizer, args.verbose_logging)
else:
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,