DataParallel for SQuAD + fix XLM
This commit is contained in:
@@ -299,10 +299,14 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
|
|
||||||
# XLNet and XLM use a more complex post-processing procedure
|
# XLNet and XLM use a more complex post-processing procedure
|
||||||
if args.model_type in ['xlnet', 'xlm']:
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
|
|
||||||
|
start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
|
||||||
|
end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top
|
||||||
|
|
||||||
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
|
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
|
||||||
args.max_answer_length, output_prediction_file,
|
args.max_answer_length, output_prediction_file,
|
||||||
output_nbest_file, output_null_log_odds_file,
|
output_nbest_file, output_null_log_odds_file,
|
||||||
model.config.start_n_top, model.config.end_n_top,
|
start_n_top, end_n_top,
|
||||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||||
else:
|
else:
|
||||||
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
||||||
|
|||||||
@@ -695,7 +695,12 @@ def compute_predictions_log_probs(
|
|||||||
tok_text = " ".join(tok_text.split())
|
tok_text = " ".join(tok_text.split())
|
||||||
orig_text = " ".join(orig_tokens)
|
orig_text = " ".join(orig_tokens)
|
||||||
|
|
||||||
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
|
if hasattr(tokenizer, "do_lower_case"):
|
||||||
|
do_lower_case = tokenizer.do_lower_case
|
||||||
|
else:
|
||||||
|
do_lower_case = tokenizer.do_lowercase_and_remove_accent
|
||||||
|
|
||||||
|
final_text = get_final_text(tok_text, orig_text, do_lower_case,
|
||||||
verbose_logging)
|
verbose_logging)
|
||||||
|
|
||||||
if final_text in seen_predictions:
|
if final_text in seen_predictions:
|
||||||
|
|||||||
@@ -549,6 +549,10 @@ class XLMTokenizer(PreTrainedTokenizer):
|
|||||||
additional_special_tokens=additional_special_tokens,
|
additional_special_tokens=additional_special_tokens,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
|
||||||
|
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
|
||||||
|
|
||||||
# cache of sm.MosesPunctNormalizer instance
|
# cache of sm.MosesPunctNormalizer instance
|
||||||
self.cache_moses_punct_normalizer = dict()
|
self.cache_moses_punct_normalizer = dict()
|
||||||
# cache of sm.MosesTokenizer instance
|
# cache of sm.MosesTokenizer instance
|
||||||
|
|||||||
Reference in New Issue
Block a user