diff --git a/examples/run_squad.py b/examples/run_squad.py index 2df29014ef..5e3f9663e2 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -299,10 +299,14 @@ def evaluate(args, model, tokenizer, prefix=""): # XLNet and XLM use a more complex post-processing procedure if args.model_type in ['xlnet', 'xlm']: + + start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top + end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top + predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size, args.max_answer_length, output_prediction_file, output_nbest_file, output_null_log_odds_file, - model.config.start_n_top, model.config.end_n_top, + start_n_top, end_n_top, args.version_2_with_negative, tokenizer, args.verbose_logging) else: predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size, diff --git a/transformers/data/metrics/squad_metrics.py b/transformers/data/metrics/squad_metrics.py index 0755c0ab7a..7b03255f49 100644 --- a/transformers/data/metrics/squad_metrics.py +++ b/transformers/data/metrics/squad_metrics.py @@ -695,7 +695,12 @@ def compute_predictions_log_probs( tok_text = " ".join(tok_text.split()) orig_text = " ".join(orig_tokens) - final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case, + if hasattr(tokenizer, "do_lower_case"): + do_lower_case = tokenizer.do_lower_case + else: + do_lower_case = tokenizer.do_lowercase_and_remove_accent + + final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) if final_text in seen_predictions: diff --git a/transformers/tokenization_xlm.py b/transformers/tokenization_xlm.py index 6c9f8e5e5c..8def80bec4 100644 --- a/transformers/tokenization_xlm.py +++ b/transformers/tokenization_xlm.py @@ -549,6 +549,10 @@ class XLMTokenizer(PreTrainedTokenizer): additional_special_tokens=additional_special_tokens, **kwargs) + + self.max_len_single_sentence = self.max_len - 2 # take into account special tokens + self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens + # cache of sm.MosesPunctNormalizer instance self.cache_moses_punct_normalizer = dict() # cache of sm.MosesTokenizer instance