From 6e2d04e429dc4ce240c99bd14b7b84550b79fd73 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Tue, 30 Jul 2024 23:36:38 +0200 Subject: [PATCH] Fix slow GemmaTokenizer and improve SPM slow -> fast conversion process (#32191) * Remove user-defined tokens which can be obtained through merges * Remove debug line * formatting * Refactor spm slow -> fast converter * revert unnecessary refactor * set comprehension * remove test files * Use `vocab_scores` * Always replace spiece underline with space in decode * we no longer need token filtering * Add save fast load slow unit test * Remove tokenizers version check * Remove duplicate code * Make `<start_of_turn>` and `<end_of_turn>` special tokens * Bias merge priority with length if score is the same * Add unit test for merge priority * CI --- src/transformers/convert_slow_tokenizer.py | 234 +++++++----------- .../models/gemma/tokenization_gemma.py | 2 +- tests/models/gemma/test_tokenization_gemma.py | 35 +++ 3 files changed, 120 insertions(+), 151 deletions(-) diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 305881288e..2d0302d3f6 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -53,6 +53,25 @@ def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str: return prepend_scheme +def generate_merges(vocab, vocab_scores): + reverse = vocab_scores is not None + vocab_scores = dict(vocab_scores) if reverse else vocab + + merges = [] + for merge, piece_score in vocab_scores.items(): + local = [] + for index in range(1, len(merge)): + piece_l, piece_r = merge[:index], merge[index:] + if piece_l in vocab and piece_r in vocab: + local.append((piece_l, piece_r, piece_score)) + local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]])) + merges.extend(local) + + merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse) + merges = [(val[0], val[1]) for val in merges] + return merges + + 
class SentencePieceExtractor: """ Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece @@ -73,24 +92,8 @@ class SentencePieceExtractor: sp = self.sp vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())} - if vocab_scores is not None: - vocab_scores, reverse = dict(vocab_scores), True - else: - vocab_scores, reverse = vocab, False + merges = generate_merges(vocab, vocab_scores) - # Merges - merges = [] - for merge, piece_score in vocab_scores.items(): - local = [] - for index in range(1, len(merge)): - piece_l, piece_r = merge[:index], merge[index:] - if piece_l in vocab and piece_r in vocab: - local.append((piece_l, piece_r, piece_score)) - local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]])) - merges.extend(local) - - merges = sorted(merges, key=lambda val: val[2], reverse=reverse) - merges = [(val[0], val[1]) for val in merges] return vocab, merges @@ -107,24 +110,7 @@ class GemmaSentencePieceExtractor(SentencePieceExtractor): # "<0x09>" is the bytefallback for `\t` vocab["\t"] = vocab.get("<0x09>") - if vocab_scores is not None: - vocab_scores, reverse = dict(vocab_scores), True - else: - vocab_scores, reverse = vocab, False - - # Merges - merges = [] - for merge, piece_score in vocab_scores.items(): - local = [] - for index in range(1, len(merge)): - piece_l, piece_r = merge[:index], merge[index:] - if piece_l in vocab and piece_r in vocab: - local.append((piece_l, piece_r, piece_score)) - local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]])) - merges.extend(local) - - merges = sorted(merges, key=lambda val: val[2], reverse=reverse) - merges = [(val[0], val[1]) for val in merges] + merges = generate_merges(vocab, vocab_scores) return vocab, merges @@ -544,6 +530,10 @@ class DebertaConverter(Converter): class SpmConverter(Converter): + handle_byte_fallback = False + SpmExtractor = SentencePieceExtractor + special_tokens = {} + def __init__(self, *args): 
requires_backends(self, "protobuf") @@ -557,14 +547,13 @@ class SpmConverter(Converter): m.ParseFromString(f.read()) self.proto = m - if self.proto.trainer_spec.byte_fallback: - if not getattr(self, "handle_byte_fallback", None): - warnings.warn( - "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option" - " which is not implemented in the fast tokenizers. In practice this means that the fast version of the" - " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these " - "unknown tokens into a sequence of byte tokens matching the original piece of text." - ) + if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback: + warnings.warn( + "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option" + " which is not implemented in the fast tokenizers. In practice this means that the fast version of the" + " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these " + "unknown tokens into a sequence of byte tokens matching the original piece of text." 
+ ) def vocab(self, proto): return [(piece.piece, piece.score) for piece in proto.pieces] @@ -575,12 +564,18 @@ class SpmConverter(Converter): def tokenizer(self, proto): model_type = proto.trainer_spec.model_type vocab_scores = self.vocab(proto) - unk_id = self.unk_id(proto) if model_type == 1: - tokenizer = Tokenizer(Unigram(vocab_scores, unk_id)) + tokenizer = Tokenizer( + Unigram( + vocab_scores, + unk_id=self.unk_id(proto), + byte_fallback=self.handle_byte_fallback, + ) + ) + elif model_type == 2: - _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract() + _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores) bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)} tokenizer = Tokenizer( BPE( @@ -588,13 +583,53 @@ class SpmConverter(Converter): merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, + byte_fallback=self.handle_byte_fallback, + dropout=None, ) ) + else: raise Exception( "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" ) + # control tokens are special + # user defined symbols are not + # both user and control tokens are AddedTokens + # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33) + spm_added_tokens = [ + (id, p.piece, p.type == 3 or p.piece in self.special_tokens) + for id, p in enumerate(proto.pieces) + if p.type in [3, 4] + ] + tokens_to_add = [ + AddedToken(token, normalized=False, special=special) + for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0]) + ] + + if len(tokens_to_add) > 0: + # super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ + # Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for + # individual tokens would repeatedly rebuild a trie, which can be slow. 
+ is_last_special = None + tokens = [] + for token in tokens_to_add: + is_special = token.special + if is_last_special is None or is_last_special == is_special: + tokens.append(token) + else: + if is_last_special: + tokenizer.add_special_tokens(tokens) + else: + tokenizer.add_tokens(tokens) + tokens = [token] + is_last_special = is_special + if tokens: + if is_last_special: + tokenizer.add_special_tokens(tokens) + else: + tokenizer.add_tokens(tokens) + return tokenizer def normalizer(self, proto): @@ -622,40 +657,6 @@ class SpmConverter(Converter): def converted(self) -> Tokenizer: tokenizer = self.tokenizer(self.proto) - # control tokens are special - # user defined symbols are not - # both user and control tokens are AddedTokens - # Add user defined symbols (type == 4) from sentnecepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33) - - tokens_to_add = { - id: AddedToken(token, normalized=False, special=special) - for id, token, special in [ - (id, p.piece, p.type == 3) for id, p in enumerate(self.proto.pieces) if p.type in [3, 4] - ] - } - tokens_to_add = [k for _, k in sorted(tokens_to_add.items(), key=lambda x: x[0])] - if len(tokens_to_add) > 0: - # super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ - # Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for - # individual tokens would repeatedly rebuild a trie, which can be slow. 
- is_last_special = None - tokens = [] - for token in tokens_to_add: - is_special = token.special - if is_last_special is None or is_last_special == is_special: - tokens.append(token) - else: - if is_last_special: - tokenizer.add_special_tokens(tokens) - else: - tokenizer.add_tokens(tokens) - tokens = [token] - is_last_special = is_special - if tokens: - if is_last_special: - tokenizer.add_special_tokens(tokens) - else: - tokenizer.add_tokens(tokens) # Tokenizer assemble normalizer = self.normalizer(self.proto) if normalizer is not None: @@ -1283,6 +1284,9 @@ class XGLMConverter(SpmConverter): class GemmaConvert(SpmConverter): handle_byte_fallback = True + SpmExtractor = GemmaSentencePieceExtractor + # start and end of turn tokens must be marked as special + special_tokens = {"", ""} """" split_by_unicode_script: true @@ -1327,45 +1331,6 @@ class GemmaConvert(SpmConverter): ] ) - def tokenizer(self, proto): - model_type = proto.trainer_spec.model_type - vocab_scores = self.vocab(proto) - if model_type == 1: - import tokenizers - - if version.parse(tokenizers.__version__) < version.parse("0.14.0"): - tokenizer = Tokenizer(Unigram(vocab_scores, 0)) - else: - tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True)) - - elif model_type == 2: - _, merges = GemmaSentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores) - bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)} - - tokenizer = Tokenizer( - BPE( - bpe_vocab, - merges, - unk_token=proto.trainer_spec.unk_piece, - fuse_unk=True, - byte_fallback=True, - dropout=None, - ) - ) - tokenizer.add_special_tokens( - [ - AddedToken("", normalized=False, special=True), - AddedToken("", normalized=False, special=True), - AddedToken("", normalized=False, special=True), - AddedToken("", normalized=False, special=True), - ] - ) - else: - raise Exception( - "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" - ) - return tokenizer - 
class LlamaConverter(SpmConverter): handle_byte_fallback = True @@ -1393,37 +1358,6 @@ class LlamaConverter(SpmConverter): sequence += [decoders.Strip(content=" ", left=1)] return decoders.Sequence(sequence) - def tokenizer(self, proto): - model_type = proto.trainer_spec.model_type - vocab_scores = self.vocab(proto) - if model_type == 1: - import tokenizers - - if version.parse(tokenizers.__version__) < version.parse("0.14.0"): - tokenizer = Tokenizer(Unigram(vocab_scores, 0)) - else: - tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True)) - - elif model_type == 2: - _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores) - bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)} - tokenizer = Tokenizer( - BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True) - ) - tokenizer.add_special_tokens( - [ - AddedToken(self.original_tokenizer.convert_ids_to_tokens(0), normalized=False, special=True), - AddedToken(self.original_tokenizer.convert_ids_to_tokens(1), normalized=False, special=True), - AddedToken(self.original_tokenizer.convert_ids_to_tokens(2), normalized=False, special=True), - ] - ) - else: - raise Exception( - "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" - ) - - return tokenizer - def normalizer(self, proto): if getattr(self.original_tokenizer, "legacy", True): sequence = [] diff --git a/src/transformers/models/gemma/tokenization_gemma.py b/src/transformers/models/gemma/tokenization_gemma.py index f70c6e807e..09e779478c 100644 --- a/src/transformers/models/gemma/tokenization_gemma.py +++ b/src/transformers/models/gemma/tokenization_gemma.py @@ -198,7 +198,7 @@ class GemmaTokenizer(PreTrainedTokenizer): else: sub_texts = "".join(sub_texts) - return sub_texts + return sub_texts.replace(SPIECE_UNDERLINE, " ") def convert_tokens_to_string(self, tokens): """Converts a sequence of tokens (string) in a 
single string.""" diff --git a/tests/models/gemma/test_tokenization_gemma.py b/tests/models/gemma/test_tokenization_gemma.py index 4201e31e6f..657c84aaa0 100644 --- a/tests/models/gemma/test_tokenization_gemma.py +++ b/tests/models/gemma/test_tokenization_gemma.py @@ -222,6 +222,17 @@ class GemmaIntegrationTest(unittest.TestCase): self.tokenizer.add_eos_token = False self.rust_tokenizer.add_eos_token = False + def test_fast_merge_priority(self): + slow_tokenizer = self.tokenizer + fast_tokenizer = self.rust_tokenizer + text = " " + target = [168, 153] + slow = slow_tokenizer.encode(text, add_special_tokens=False) + assert slow == target + + fast = fast_tokenizer.encode(text, add_special_tokens=False) + assert fast == target + @unittest.skip(reason="Not super important and always failing. Let's skip it") @slow def test_conversion(self): @@ -442,6 +453,30 @@ class GemmaIntegrationTest(unittest.TestCase): for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): self.assertListEqual(tokenized_chat, expected_tokens) + def test_save_fast_load_slow(self): + # Ensure that we can save a fast tokenizer and load it as a slow tokenizer + slow_tokenizer = self.tokenizer + text = "a " + target_encoded = [2, 235250, 139] + slow = slow_tokenizer.encode(text, add_special_tokens=True) + assert slow == target_encoded + + slow_decoded = slow_tokenizer.decode(slow, skip_special_tokens=True) + assert slow_decoded == text + + with tempfile.TemporaryDirectory() as dirname: + # Save fast tokenizer + self.rust_tokenizer.save_pretrained(dirname) + + # Load slow tokenizer with fast files present in the directory + slow_tokenizer_from_fast = GemmaTokenizer.from_pretrained(dirname) + + slow_from_fast = slow_tokenizer_from_fast.encode(text, add_special_tokens=True) + assert slow_from_fast == target_encoded + + slow_from_fast_decoded = slow_tokenizer_from_fast.decode(slow, skip_special_tokens=True) + assert slow_from_fast_decoded == text + @require_sentencepiece 
@require_tokenizers