DataParallel for SQuAD + fix XLM
This commit is contained in:
@@ -299,10 +299,14 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
|
||||
# XLNet and XLM use a more complex post-processing procedure
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
|
||||
start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
|
||||
end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top
|
||||
|
||||
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file,
|
||||
-model.config.start_n_top, model.config.end_n_top,
|
||||
start_n_top, end_n_top,
|
||||
args.version_2_with_negative, tokenizer, args.verbose_logging)
|
||||
else:
|
||||
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
|
||||
|
||||
@@ -695,7 +695,12 @@ def compute_predictions_log_probs(
|
||||
tok_text = " ".join(tok_text.split())
|
||||
orig_text = " ".join(orig_tokens)
|
||||
|
||||
-final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
|
||||
if hasattr(tokenizer, "do_lower_case"):
|
||||
do_lower_case = tokenizer.do_lower_case
|
||||
else:
|
||||
do_lower_case = tokenizer.do_lowercase_and_remove_accent
|
||||
|
||||
final_text = get_final_text(tok_text, orig_text, do_lower_case,
|
||||
verbose_logging)
|
||||
|
||||
if final_text in seen_predictions:
|
||||
|
||||
@@ -549,6 +549,10 @@ class XLMTokenizer(PreTrainedTokenizer):
|
||||
additional_special_tokens=additional_special_tokens,
|
||||
**kwargs)
|
||||
|
||||
|
||||
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
|
||||
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
|
||||
|
||||
# cache of sm.MosesPunctNormalizer instance
|
||||
self.cache_moses_punct_normalizer = dict()
|
||||
# cache of sm.MosesTokenizer instance
|
||||
|
||||
Reference in New Issue
Block a user