Fix slow GemmaTokenizer and improve SPM slow -> fast conversion process (#32191)
* Remove user-defined tokens which can be obtained through merges * Remove debug line * formatting * Refactor spm slow -> fast converter * revert unnecessary refactor * set comprehension * remove test files * Use `vocab_scores` * Always replace spiece underline with space in decode * we no longer need token filtering * Add save fast load slow unit test * Remove tokenizers version check * Remove duplicate code * Make `<start_of_turn>` and `<end_of_turn>` special tokens * Bias merge priority with length if score is the same * Add unit test for merge priority * CI
This commit is contained in:
@@ -53,6 +53,25 @@ def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
|
||||
return prepend_scheme
|
||||
|
||||
|
||||
def generate_merges(vocab, vocab_scores):
|
||||
reverse = vocab_scores is not None
|
||||
vocab_scores = dict(vocab_scores) if reverse else vocab
|
||||
|
||||
merges = []
|
||||
for merge, piece_score in vocab_scores.items():
|
||||
local = []
|
||||
for index in range(1, len(merge)):
|
||||
piece_l, piece_r = merge[:index], merge[index:]
|
||||
if piece_l in vocab and piece_r in vocab:
|
||||
local.append((piece_l, piece_r, piece_score))
|
||||
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
|
||||
merges.extend(local)
|
||||
|
||||
merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
|
||||
merges = [(val[0], val[1]) for val in merges]
|
||||
return merges
|
||||
|
||||
|
||||
class SentencePieceExtractor:
|
||||
"""
|
||||
Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
|
||||
@@ -73,24 +92,8 @@ class SentencePieceExtractor:
|
||||
sp = self.sp
|
||||
vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
|
||||
|
||||
if vocab_scores is not None:
|
||||
vocab_scores, reverse = dict(vocab_scores), True
|
||||
else:
|
||||
vocab_scores, reverse = vocab, False
|
||||
merges = generate_merges(vocab, vocab_scores)
|
||||
|
||||
# Merges
|
||||
merges = []
|
||||
for merge, piece_score in vocab_scores.items():
|
||||
local = []
|
||||
for index in range(1, len(merge)):
|
||||
piece_l, piece_r = merge[:index], merge[index:]
|
||||
if piece_l in vocab and piece_r in vocab:
|
||||
local.append((piece_l, piece_r, piece_score))
|
||||
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
|
||||
merges.extend(local)
|
||||
|
||||
merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
|
||||
merges = [(val[0], val[1]) for val in merges]
|
||||
return vocab, merges
|
||||
|
||||
|
||||
@@ -107,24 +110,7 @@ class GemmaSentencePieceExtractor(SentencePieceExtractor):
|
||||
# "<0x09>" is the bytefallback for `\t`
|
||||
vocab["\t"] = vocab.get("<0x09>")
|
||||
|
||||
if vocab_scores is not None:
|
||||
vocab_scores, reverse = dict(vocab_scores), True
|
||||
else:
|
||||
vocab_scores, reverse = vocab, False
|
||||
|
||||
# Merges
|
||||
merges = []
|
||||
for merge, piece_score in vocab_scores.items():
|
||||
local = []
|
||||
for index in range(1, len(merge)):
|
||||
piece_l, piece_r = merge[:index], merge[index:]
|
||||
if piece_l in vocab and piece_r in vocab:
|
||||
local.append((piece_l, piece_r, piece_score))
|
||||
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
|
||||
merges.extend(local)
|
||||
|
||||
merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
|
||||
merges = [(val[0], val[1]) for val in merges]
|
||||
merges = generate_merges(vocab, vocab_scores)
|
||||
return vocab, merges
|
||||
|
||||
|
||||
@@ -544,6 +530,10 @@ class DebertaConverter(Converter):
|
||||
|
||||
|
||||
class SpmConverter(Converter):
|
||||
handle_byte_fallback = False
|
||||
SpmExtractor = SentencePieceExtractor
|
||||
special_tokens = {}
|
||||
|
||||
def __init__(self, *args):
|
||||
requires_backends(self, "protobuf")
|
||||
|
||||
@@ -557,8 +547,7 @@ class SpmConverter(Converter):
|
||||
m.ParseFromString(f.read())
|
||||
self.proto = m
|
||||
|
||||
if self.proto.trainer_spec.byte_fallback:
|
||||
if not getattr(self, "handle_byte_fallback", None):
|
||||
if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
|
||||
warnings.warn(
|
||||
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
||||
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
|
||||
@@ -575,12 +564,18 @@ class SpmConverter(Converter):
|
||||
def tokenizer(self, proto):
|
||||
model_type = proto.trainer_spec.model_type
|
||||
vocab_scores = self.vocab(proto)
|
||||
unk_id = self.unk_id(proto)
|
||||
|
||||
if model_type == 1:
|
||||
tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
|
||||
tokenizer = Tokenizer(
|
||||
Unigram(
|
||||
vocab_scores,
|
||||
unk_id=self.unk_id(proto),
|
||||
byte_fallback=self.handle_byte_fallback,
|
||||
)
|
||||
)
|
||||
|
||||
elif model_type == 2:
|
||||
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
|
||||
_, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
|
||||
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
|
||||
tokenizer = Tokenizer(
|
||||
BPE(
|
||||
@@ -588,13 +583,53 @@ class SpmConverter(Converter):
|
||||
merges,
|
||||
unk_token=proto.trainer_spec.unk_piece,
|
||||
fuse_unk=True,
|
||||
byte_fallback=self.handle_byte_fallback,
|
||||
dropout=None,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception(
|
||||
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
||||
)
|
||||
|
||||
# control tokens are special
|
||||
# user defined symbols are not
|
||||
# both user and control tokens are AddedTokens
|
||||
# Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
|
||||
spm_added_tokens = [
|
||||
(id, p.piece, p.type == 3 or p.piece in self.special_tokens)
|
||||
for id, p in enumerate(proto.pieces)
|
||||
if p.type in [3, 4]
|
||||
]
|
||||
tokens_to_add = [
|
||||
AddedToken(token, normalized=False, special=special)
|
||||
for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
|
||||
]
|
||||
|
||||
if len(tokens_to_add) > 0:
|
||||
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
|
||||
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
|
||||
# individual tokens would repeatedly rebuild a trie, which can be slow.
|
||||
is_last_special = None
|
||||
tokens = []
|
||||
for token in tokens_to_add:
|
||||
is_special = token.special
|
||||
if is_last_special is None or is_last_special == is_special:
|
||||
tokens.append(token)
|
||||
else:
|
||||
if is_last_special:
|
||||
tokenizer.add_special_tokens(tokens)
|
||||
else:
|
||||
tokenizer.add_tokens(tokens)
|
||||
tokens = [token]
|
||||
is_last_special = is_special
|
||||
if tokens:
|
||||
if is_last_special:
|
||||
tokenizer.add_special_tokens(tokens)
|
||||
else:
|
||||
tokenizer.add_tokens(tokens)
|
||||
|
||||
return tokenizer
|
||||
|
||||
def normalizer(self, proto):
|
||||
@@ -622,40 +657,6 @@ class SpmConverter(Converter):
|
||||
def converted(self) -> Tokenizer:
|
||||
tokenizer = self.tokenizer(self.proto)
|
||||
|
||||
# control tokens are special
|
||||
# user defined symbols are not
|
||||
# both user and control tokens are AddedTokens
|
||||
# Add user defined symbols (type == 4) from sentnecepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
|
||||
|
||||
tokens_to_add = {
|
||||
id: AddedToken(token, normalized=False, special=special)
|
||||
for id, token, special in [
|
||||
(id, p.piece, p.type == 3) for id, p in enumerate(self.proto.pieces) if p.type in [3, 4]
|
||||
]
|
||||
}
|
||||
tokens_to_add = [k for _, k in sorted(tokens_to_add.items(), key=lambda x: x[0])]
|
||||
if len(tokens_to_add) > 0:
|
||||
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
|
||||
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
|
||||
# individual tokens would repeatedly rebuild a trie, which can be slow.
|
||||
is_last_special = None
|
||||
tokens = []
|
||||
for token in tokens_to_add:
|
||||
is_special = token.special
|
||||
if is_last_special is None or is_last_special == is_special:
|
||||
tokens.append(token)
|
||||
else:
|
||||
if is_last_special:
|
||||
tokenizer.add_special_tokens(tokens)
|
||||
else:
|
||||
tokenizer.add_tokens(tokens)
|
||||
tokens = [token]
|
||||
is_last_special = is_special
|
||||
if tokens:
|
||||
if is_last_special:
|
||||
tokenizer.add_special_tokens(tokens)
|
||||
else:
|
||||
tokenizer.add_tokens(tokens)
|
||||
# Tokenizer assemble
|
||||
normalizer = self.normalizer(self.proto)
|
||||
if normalizer is not None:
|
||||
@@ -1283,6 +1284,9 @@ class XGLMConverter(SpmConverter):
|
||||
|
||||
class GemmaConvert(SpmConverter):
|
||||
handle_byte_fallback = True
|
||||
SpmExtractor = GemmaSentencePieceExtractor
|
||||
# start and end of turn tokens must be marked as special
|
||||
special_tokens = {"<start_of_turn>", "<end_of_turn>"}
|
||||
|
||||
""""
|
||||
split_by_unicode_script: true
|
||||
@@ -1327,45 +1331,6 @@ class GemmaConvert(SpmConverter):
|
||||
]
|
||||
)
|
||||
|
||||
def tokenizer(self, proto):
|
||||
model_type = proto.trainer_spec.model_type
|
||||
vocab_scores = self.vocab(proto)
|
||||
if model_type == 1:
|
||||
import tokenizers
|
||||
|
||||
if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
|
||||
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
|
||||
else:
|
||||
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
|
||||
|
||||
elif model_type == 2:
|
||||
_, merges = GemmaSentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
|
||||
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
|
||||
|
||||
tokenizer = Tokenizer(
|
||||
BPE(
|
||||
bpe_vocab,
|
||||
merges,
|
||||
unk_token=proto.trainer_spec.unk_piece,
|
||||
fuse_unk=True,
|
||||
byte_fallback=True,
|
||||
dropout=None,
|
||||
)
|
||||
)
|
||||
tokenizer.add_special_tokens(
|
||||
[
|
||||
AddedToken("<pad>", normalized=False, special=True),
|
||||
AddedToken("<eos>", normalized=False, special=True),
|
||||
AddedToken("<bos>", normalized=False, special=True),
|
||||
AddedToken("<unk>", normalized=False, special=True),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
|
||||
class LlamaConverter(SpmConverter):
|
||||
handle_byte_fallback = True
|
||||
@@ -1393,37 +1358,6 @@ class LlamaConverter(SpmConverter):
|
||||
sequence += [decoders.Strip(content=" ", left=1)]
|
||||
return decoders.Sequence(sequence)
|
||||
|
||||
def tokenizer(self, proto):
|
||||
model_type = proto.trainer_spec.model_type
|
||||
vocab_scores = self.vocab(proto)
|
||||
if model_type == 1:
|
||||
import tokenizers
|
||||
|
||||
if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
|
||||
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
|
||||
else:
|
||||
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
|
||||
|
||||
elif model_type == 2:
|
||||
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
|
||||
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
|
||||
tokenizer = Tokenizer(
|
||||
BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
|
||||
)
|
||||
tokenizer.add_special_tokens(
|
||||
[
|
||||
AddedToken(self.original_tokenizer.convert_ids_to_tokens(0), normalized=False, special=True),
|
||||
AddedToken(self.original_tokenizer.convert_ids_to_tokens(1), normalized=False, special=True),
|
||||
AddedToken(self.original_tokenizer.convert_ids_to_tokens(2), normalized=False, special=True),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
||||
)
|
||||
|
||||
return tokenizer
|
||||
|
||||
def normalizer(self, proto):
|
||||
if getattr(self.original_tokenizer, "legacy", True):
|
||||
sequence = []
|
||||
|
||||
@@ -198,7 +198,7 @@ class GemmaTokenizer(PreTrainedTokenizer):
|
||||
else:
|
||||
sub_texts = "".join(sub_texts)
|
||||
|
||||
return sub_texts
|
||||
return sub_texts.replace(SPIECE_UNDERLINE, " ")
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
|
||||
@@ -222,6 +222,17 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
self.tokenizer.add_eos_token = False
|
||||
self.rust_tokenizer.add_eos_token = False
|
||||
|
||||
def test_fast_merge_priority(self):
|
||||
slow_tokenizer = self.tokenizer
|
||||
fast_tokenizer = self.rust_tokenizer
|
||||
text = " "
|
||||
target = [168, 153]
|
||||
slow = slow_tokenizer.encode(text, add_special_tokens=False)
|
||||
assert slow == target
|
||||
|
||||
fast = fast_tokenizer.encode(text, add_special_tokens=False)
|
||||
assert fast == target
|
||||
|
||||
@unittest.skip(reason="Not super important and always failing. Let's skip it")
|
||||
@slow
|
||||
def test_conversion(self):
|
||||
@@ -442,6 +453,30 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
|
||||
self.assertListEqual(tokenized_chat, expected_tokens)
|
||||
|
||||
def test_save_fast_load_slow(self):
|
||||
# Ensure that we can save a fast tokenizer and load it as a slow tokenizer
|
||||
slow_tokenizer = self.tokenizer
|
||||
text = "a "
|
||||
target_encoded = [2, 235250, 139]
|
||||
slow = slow_tokenizer.encode(text, add_special_tokens=True)
|
||||
assert slow == target_encoded
|
||||
|
||||
slow_decoded = slow_tokenizer.decode(slow, skip_special_tokens=True)
|
||||
assert slow_decoded == text
|
||||
|
||||
with tempfile.TemporaryDirectory() as dirname:
|
||||
# Save fast tokenizer
|
||||
self.rust_tokenizer.save_pretrained(dirname)
|
||||
|
||||
# Load slow tokenizer with fast files present in the directory
|
||||
slow_tokenizer_from_fast = GemmaTokenizer.from_pretrained(dirname)
|
||||
|
||||
slow_from_fast = slow_tokenizer_from_fast.encode(text, add_special_tokens=True)
|
||||
assert slow_from_fast == target_encoded
|
||||
|
||||
slow_from_fast_decoded = slow_tokenizer_from_fast.decode(slow, skip_special_tokens=True)
|
||||
assert slow_from_fast_decoded == text
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
|
||||
Reference in New Issue
Block a user