From a67e7478894c477822760ad7f9933a7d78aa27bd Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 14 Nov 2019 10:30:22 -0500 Subject: [PATCH] Reorganized max_len warning --- transformers/tokenization_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index cd14cc4582..c5f469800b 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -671,10 +671,6 @@ class PreTrainedTokenizer(object): ids = [] for token in tokens: ids.append(self._convert_token_to_id_with_added_voc(token)) - if len(ids) > self.max_len: - logger.warning("Token indices sequence length is longer than the specified maximum sequence length " - "for this model ({} > {}). Running this sequence through the model will result in " - "indexing errors".format(len(ids), self.max_len)) return ids def _convert_token_to_id_with_added_voc(self, token): @@ -877,6 +873,11 @@ class PreTrainedTokenizer(object): encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length] encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length] + if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len: + logger.warning("Token indices sequence length is longer than the specified maximum sequence length " + "for this model ({} > {}). Running this sequence through the model will result in " + "indexing errors".format(len(ids), self.max_len)) + return encoded_inputs def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):