From 6ee1474b67b088829555364a14ebfb45e661fac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADt=20Novotn=C3=BD?= Date: Tue, 31 May 2022 16:36:45 +0200 Subject: [PATCH] Accumulate tokens into batches in `PreTrainedTokenizerBase.add_tokens()` (#17119) * Accumulate tokens into batches in PreTrainedTokenizerBase.add_tokens() For tokenizers with a small number of special tokens or special tokens with consecutive token IDs, this reduces the time complexity of creating the trie from quadratic to linear, see also #16936. * Extend explanation of batching added tokens --- src/transformers/tokenization_utils_base.py | 24 ++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index c127c19f1f..4bd0fab75a 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1964,24 +1964,38 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): # Sort added tokens by index added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1])) + # Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for + # individual tokens would repeatedly rebuild a trie, which can be slow. + is_last_special = None + tokens = [] + for token, index in added_tok_encoder_sorted: - if has_tokenizer_file and index != len(tokenizer) and tokenizer.convert_tokens_to_ids(token) != index: + current_index = len(tokenizer) + len(tokens) + if has_tokenizer_file and index != current_index and tokenizer.convert_tokens_to_ids(token) != index: # Tokenizer fast: added token needs to either be in the vocabulary with the proper index or the # index is the current length of the tokenizer (not in vocabulary) raise ValueError( f"Wrong index found for {token}: should be {tokenizer.convert_tokens_to_ids(token)} but found " f"{index}." ) - elif not has_tokenizer_file and index != len(tokenizer): + elif not has_tokenizer_file and index != current_index: # Tokenizer slow: added token cannot already be in the vocabulary so its index needs to be the # current length of the tokenizer. raise ValueError( f"Non-consecutive added token '{token}' found. " - f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary." + f"Should have index {current_index} but has index {index} in saved vocabulary." ) - # Safe to call on a tokenizer fast even if token already there. - tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens)) + is_special = bool(token in special_tokens) + if is_last_special is None or is_last_special == is_special: + tokens.append(token) + else: + tokenizer.add_tokens(tokens, special_tokens=is_last_special) + tokens = [token] + is_last_special = is_special + + if tokens: + tokenizer.add_tokens(tokens, special_tokens=is_last_special) # Check all our special tokens are registered as "no split" token (we don't cut them) and are in the vocab added_tokens = tokenizer.sanitize_special_tokens()