Fix slow GemmaTokenizer and improve SPM slow -> fast conversion process (#32191)
* Remove user-defined tokens which can be obtained through merges
* Remove debug line
* formatting
* Refactor spm slow -> fast converter
* revert unnecessary refactor
* set comprehension
* remove test files
* Use `vocab_scores`
* Always replace spiece underline with space in decode
* we no longer need token filtering
* Add save fast load slow unit test
* Remove tokenizers version check
* Remove duplicate code
* Make `<start_of_turn>` and `<end_of_turn>` special tokens
* Bias merge priority with length if score is the same
* Add unit test for merge priority
* CI
This commit is contained in:
@@ -53,6 +53,25 @@ def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
|
|||||||
return prepend_scheme
|
return prepend_scheme
|
||||||
|
|
||||||
|
|
||||||
|
def generate_merges(vocab, vocab_scores):
    """
    Derive an ordered list of BPE merge rules from a vocabulary.

    Every piece is cut at each interior position; a cut is kept as a merge candidate
    when both halves are themselves entries of `vocab`.

    Args:
        vocab: Mapping from piece to a sortable value (used both for membership tests
            and to order the candidate cuts of a single piece).
        vocab_scores: Anything `dict()` accepts that maps piece -> score, or `None`.
            When `None`, `vocab` itself is iterated and its values act as the scores.

    Returns:
        A list of `(left, right)` tuples. With explicit scores the list is ordered by
        descending `(score, len(left), len(right))`; without them it is ordered
        ascending by the same key. Ties keep the per-piece `(vocab[left], vocab[right])`
        order because both sorts are stable.
    """
    has_scores = vocab_scores is not None
    scored_pieces = dict(vocab_scores) if has_scores else vocab

    candidates = []
    for piece, score in scored_pieces.items():
        # All ways of splitting this piece into two known sub-pieces.
        splits = [
            (piece[:cut], piece[cut:], score)
            for cut in range(1, len(piece))
            if piece[:cut] in vocab and piece[cut:] in vocab
        ]
        splits.sort(key=lambda split: (vocab[split[0]], vocab[split[1]]))
        candidates += splits

    # Equal scores are disambiguated by the lengths of the left, then right, halves.
    candidates.sort(
        key=lambda split: (split[2], len(split[0]), len(split[1])),
        reverse=has_scores,
    )
    return [(left, right) for left, right, _ in candidates]
|
||||||
|
|
||||||
|
|
||||||
class SentencePieceExtractor:
|
class SentencePieceExtractor:
|
||||||
"""
|
"""
|
||||||
Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
|
Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
|
||||||
@@ -73,24 +92,8 @@ class SentencePieceExtractor:
|
|||||||
sp = self.sp
|
sp = self.sp
|
||||||
vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
|
vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
|
||||||
|
|
||||||
if vocab_scores is not None:
|
merges = generate_merges(vocab, vocab_scores)
|
||||||
vocab_scores, reverse = dict(vocab_scores), True
|
|
||||||
else:
|
|
||||||
vocab_scores, reverse = vocab, False
|
|
||||||
|
|
||||||
# Merges
|
|
||||||
merges = []
|
|
||||||
for merge, piece_score in vocab_scores.items():
|
|
||||||
local = []
|
|
||||||
for index in range(1, len(merge)):
|
|
||||||
piece_l, piece_r = merge[:index], merge[index:]
|
|
||||||
if piece_l in vocab and piece_r in vocab:
|
|
||||||
local.append((piece_l, piece_r, piece_score))
|
|
||||||
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
|
|
||||||
merges.extend(local)
|
|
||||||
|
|
||||||
merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
|
|
||||||
merges = [(val[0], val[1]) for val in merges]
|
|
||||||
return vocab, merges
|
return vocab, merges
|
||||||
|
|
||||||
|
|
||||||
@@ -107,24 +110,7 @@ class GemmaSentencePieceExtractor(SentencePieceExtractor):
|
|||||||
# "<0x09>" is the bytefallback for `\t`
|
# "<0x09>" is the bytefallback for `\t`
|
||||||
vocab["\t"] = vocab.get("<0x09>")
|
vocab["\t"] = vocab.get("<0x09>")
|
||||||
|
|
||||||
if vocab_scores is not None:
|
merges = generate_merges(vocab, vocab_scores)
|
||||||
vocab_scores, reverse = dict(vocab_scores), True
|
|
||||||
else:
|
|
||||||
vocab_scores, reverse = vocab, False
|
|
||||||
|
|
||||||
# Merges
|
|
||||||
merges = []
|
|
||||||
for merge, piece_score in vocab_scores.items():
|
|
||||||
local = []
|
|
||||||
for index in range(1, len(merge)):
|
|
||||||
piece_l, piece_r = merge[:index], merge[index:]
|
|
||||||
if piece_l in vocab and piece_r in vocab:
|
|
||||||
local.append((piece_l, piece_r, piece_score))
|
|
||||||
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
|
|
||||||
merges.extend(local)
|
|
||||||
|
|
||||||
merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
|
|
||||||
merges = [(val[0], val[1]) for val in merges]
|
|
||||||
return vocab, merges
|
return vocab, merges
|
||||||
|
|
||||||
|
|
||||||
@@ -544,6 +530,10 @@ class DebertaConverter(Converter):
|
|||||||
|
|
||||||
|
|
||||||
class SpmConverter(Converter):
|
class SpmConverter(Converter):
|
||||||
|
handle_byte_fallback = False
|
||||||
|
SpmExtractor = SentencePieceExtractor
|
||||||
|
special_tokens = {}
|
||||||
|
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
requires_backends(self, "protobuf")
|
requires_backends(self, "protobuf")
|
||||||
|
|
||||||
@@ -557,14 +547,13 @@ class SpmConverter(Converter):
|
|||||||
m.ParseFromString(f.read())
|
m.ParseFromString(f.read())
|
||||||
self.proto = m
|
self.proto = m
|
||||||
|
|
||||||
if self.proto.trainer_spec.byte_fallback:
|
if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
|
||||||
if not getattr(self, "handle_byte_fallback", None):
|
warnings.warn(
|
||||||
warnings.warn(
|
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
||||||
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
|
||||||
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
|
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
|
||||||
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
|
"unknown tokens into a sequence of byte tokens matching the original piece of text."
|
||||||
"unknown tokens into a sequence of byte tokens matching the original piece of text."
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def vocab(self, proto):
|
def vocab(self, proto):
|
||||||
return [(piece.piece, piece.score) for piece in proto.pieces]
|
return [(piece.piece, piece.score) for piece in proto.pieces]
|
||||||
@@ -575,12 +564,18 @@ class SpmConverter(Converter):
|
|||||||
def tokenizer(self, proto):
|
def tokenizer(self, proto):
|
||||||
model_type = proto.trainer_spec.model_type
|
model_type = proto.trainer_spec.model_type
|
||||||
vocab_scores = self.vocab(proto)
|
vocab_scores = self.vocab(proto)
|
||||||
unk_id = self.unk_id(proto)
|
|
||||||
|
|
||||||
if model_type == 1:
|
if model_type == 1:
|
||||||
tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
|
tokenizer = Tokenizer(
|
||||||
|
Unigram(
|
||||||
|
vocab_scores,
|
||||||
|
unk_id=self.unk_id(proto),
|
||||||
|
byte_fallback=self.handle_byte_fallback,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
elif model_type == 2:
|
elif model_type == 2:
|
||||||
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
|
_, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
|
||||||
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
|
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
|
||||||
tokenizer = Tokenizer(
|
tokenizer = Tokenizer(
|
||||||
BPE(
|
BPE(
|
||||||
@@ -588,13 +583,53 @@ class SpmConverter(Converter):
|
|||||||
merges,
|
merges,
|
||||||
unk_token=proto.trainer_spec.unk_piece,
|
unk_token=proto.trainer_spec.unk_piece,
|
||||||
fuse_unk=True,
|
fuse_unk=True,
|
||||||
|
byte_fallback=self.handle_byte_fallback,
|
||||||
|
dropout=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# control tokens are special
|
||||||
|
# user defined symbols are not
|
||||||
|
# both user and control tokens are AddedTokens
|
||||||
|
# Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
|
||||||
|
spm_added_tokens = [
|
||||||
|
(id, p.piece, p.type == 3 or p.piece in self.special_tokens)
|
||||||
|
for id, p in enumerate(proto.pieces)
|
||||||
|
if p.type in [3, 4]
|
||||||
|
]
|
||||||
|
tokens_to_add = [
|
||||||
|
AddedToken(token, normalized=False, special=special)
|
||||||
|
for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(tokens_to_add) > 0:
|
||||||
|
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
|
||||||
|
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
|
||||||
|
# individual tokens would repeatedly rebuild a trie, which can be slow.
|
||||||
|
is_last_special = None
|
||||||
|
tokens = []
|
||||||
|
for token in tokens_to_add:
|
||||||
|
is_special = token.special
|
||||||
|
if is_last_special is None or is_last_special == is_special:
|
||||||
|
tokens.append(token)
|
||||||
|
else:
|
||||||
|
if is_last_special:
|
||||||
|
tokenizer.add_special_tokens(tokens)
|
||||||
|
else:
|
||||||
|
tokenizer.add_tokens(tokens)
|
||||||
|
tokens = [token]
|
||||||
|
is_last_special = is_special
|
||||||
|
if tokens:
|
||||||
|
if is_last_special:
|
||||||
|
tokenizer.add_special_tokens(tokens)
|
||||||
|
else:
|
||||||
|
tokenizer.add_tokens(tokens)
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
def normalizer(self, proto):
|
def normalizer(self, proto):
|
||||||
@@ -622,40 +657,6 @@ class SpmConverter(Converter):
|
|||||||
def converted(self) -> Tokenizer:
|
def converted(self) -> Tokenizer:
|
||||||
tokenizer = self.tokenizer(self.proto)
|
tokenizer = self.tokenizer(self.proto)
|
||||||
|
|
||||||
# control tokens are special
|
|
||||||
# user defined symbols are not
|
|
||||||
# both user and control tokens are AddedTokens
|
|
||||||
# Add user defined symbols (type == 4) from sentnecepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
|
|
||||||
|
|
||||||
tokens_to_add = {
|
|
||||||
id: AddedToken(token, normalized=False, special=special)
|
|
||||||
for id, token, special in [
|
|
||||||
(id, p.piece, p.type == 3) for id, p in enumerate(self.proto.pieces) if p.type in [3, 4]
|
|
||||||
]
|
|
||||||
}
|
|
||||||
tokens_to_add = [k for _, k in sorted(tokens_to_add.items(), key=lambda x: x[0])]
|
|
||||||
if len(tokens_to_add) > 0:
|
|
||||||
# super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
|
|
||||||
# Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
|
|
||||||
# individual tokens would repeatedly rebuild a trie, which can be slow.
|
|
||||||
is_last_special = None
|
|
||||||
tokens = []
|
|
||||||
for token in tokens_to_add:
|
|
||||||
is_special = token.special
|
|
||||||
if is_last_special is None or is_last_special == is_special:
|
|
||||||
tokens.append(token)
|
|
||||||
else:
|
|
||||||
if is_last_special:
|
|
||||||
tokenizer.add_special_tokens(tokens)
|
|
||||||
else:
|
|
||||||
tokenizer.add_tokens(tokens)
|
|
||||||
tokens = [token]
|
|
||||||
is_last_special = is_special
|
|
||||||
if tokens:
|
|
||||||
if is_last_special:
|
|
||||||
tokenizer.add_special_tokens(tokens)
|
|
||||||
else:
|
|
||||||
tokenizer.add_tokens(tokens)
|
|
||||||
# Tokenizer assemble
|
# Tokenizer assemble
|
||||||
normalizer = self.normalizer(self.proto)
|
normalizer = self.normalizer(self.proto)
|
||||||
if normalizer is not None:
|
if normalizer is not None:
|
||||||
@@ -1283,6 +1284,9 @@ class XGLMConverter(SpmConverter):
|
|||||||
|
|
||||||
class GemmaConvert(SpmConverter):
|
class GemmaConvert(SpmConverter):
|
||||||
handle_byte_fallback = True
|
handle_byte_fallback = True
|
||||||
|
SpmExtractor = GemmaSentencePieceExtractor
|
||||||
|
# start and end of turn tokens must be marked as special
|
||||||
|
special_tokens = {"<start_of_turn>", "<end_of_turn>"}
|
||||||
|
|
||||||
""""
|
""""
|
||||||
split_by_unicode_script: true
|
split_by_unicode_script: true
|
||||||
@@ -1327,45 +1331,6 @@ class GemmaConvert(SpmConverter):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def tokenizer(self, proto):
|
|
||||||
model_type = proto.trainer_spec.model_type
|
|
||||||
vocab_scores = self.vocab(proto)
|
|
||||||
if model_type == 1:
|
|
||||||
import tokenizers
|
|
||||||
|
|
||||||
if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
|
|
||||||
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
|
|
||||||
else:
|
|
||||||
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
|
|
||||||
|
|
||||||
elif model_type == 2:
|
|
||||||
_, merges = GemmaSentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
|
|
||||||
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
|
|
||||||
|
|
||||||
tokenizer = Tokenizer(
|
|
||||||
BPE(
|
|
||||||
bpe_vocab,
|
|
||||||
merges,
|
|
||||||
unk_token=proto.trainer_spec.unk_piece,
|
|
||||||
fuse_unk=True,
|
|
||||||
byte_fallback=True,
|
|
||||||
dropout=None,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tokenizer.add_special_tokens(
|
|
||||||
[
|
|
||||||
AddedToken("<pad>", normalized=False, special=True),
|
|
||||||
AddedToken("<eos>", normalized=False, special=True),
|
|
||||||
AddedToken("<bos>", normalized=False, special=True),
|
|
||||||
AddedToken("<unk>", normalized=False, special=True),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
|
||||||
)
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaConverter(SpmConverter):
|
class LlamaConverter(SpmConverter):
|
||||||
handle_byte_fallback = True
|
handle_byte_fallback = True
|
||||||
@@ -1393,37 +1358,6 @@ class LlamaConverter(SpmConverter):
|
|||||||
sequence += [decoders.Strip(content=" ", left=1)]
|
sequence += [decoders.Strip(content=" ", left=1)]
|
||||||
return decoders.Sequence(sequence)
|
return decoders.Sequence(sequence)
|
||||||
|
|
||||||
def tokenizer(self, proto):
|
|
||||||
model_type = proto.trainer_spec.model_type
|
|
||||||
vocab_scores = self.vocab(proto)
|
|
||||||
if model_type == 1:
|
|
||||||
import tokenizers
|
|
||||||
|
|
||||||
if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
|
|
||||||
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
|
|
||||||
else:
|
|
||||||
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
|
|
||||||
|
|
||||||
elif model_type == 2:
|
|
||||||
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
|
|
||||||
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
|
|
||||||
tokenizer = Tokenizer(
|
|
||||||
BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
|
|
||||||
)
|
|
||||||
tokenizer.add_special_tokens(
|
|
||||||
[
|
|
||||||
AddedToken(self.original_tokenizer.convert_ids_to_tokens(0), normalized=False, special=True),
|
|
||||||
AddedToken(self.original_tokenizer.convert_ids_to_tokens(1), normalized=False, special=True),
|
|
||||||
AddedToken(self.original_tokenizer.convert_ids_to_tokens(2), normalized=False, special=True),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
|
||||||
)
|
|
||||||
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
def normalizer(self, proto):
|
def normalizer(self, proto):
|
||||||
if getattr(self.original_tokenizer, "legacy", True):
|
if getattr(self.original_tokenizer, "legacy", True):
|
||||||
sequence = []
|
sequence = []
|
||||||
|
|||||||
@@ -198,7 +198,7 @@ class GemmaTokenizer(PreTrainedTokenizer):
|
|||||||
else:
|
else:
|
||||||
sub_texts = "".join(sub_texts)
|
sub_texts = "".join(sub_texts)
|
||||||
|
|
||||||
return sub_texts
|
return sub_texts.replace(SPIECE_UNDERLINE, " ")
|
||||||
|
|
||||||
def convert_tokens_to_string(self, tokens):
|
def convert_tokens_to_string(self, tokens):
|
||||||
"""Converts a sequence of tokens (string) in a single string."""
|
"""Converts a sequence of tokens (string) in a single string."""
|
||||||
|
|||||||
@@ -222,6 +222,17 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
self.tokenizer.add_eos_token = False
|
self.tokenizer.add_eos_token = False
|
||||||
self.rust_tokenizer.add_eos_token = False
|
self.rust_tokenizer.add_eos_token = False
|
||||||
|
|
||||||
|
def test_fast_merge_priority(self):
    """Check that the slow and fast tokenizers both encode `text` to the same expected ids."""
    slow_tokenizer = self.tokenizer
    fast_tokenizer = self.rust_tokenizer
    # NOTE(review): whitespace inside this literal may have been collapsed by the diff
    # rendering this was taken from (a two-id target suggests a longer run of spaces) —
    # confirm against the repository.
    text = " "
    target = [168, 153]

    # unittest assertions instead of bare `assert`: they are not stripped under
    # `python -O` and report both values on failure.
    slow = slow_tokenizer.encode(text, add_special_tokens=False)
    self.assertEqual(slow, target)

    fast = fast_tokenizer.encode(text, add_special_tokens=False)
    self.assertEqual(fast, target)
|
||||||
|
|
||||||
@unittest.skip(reason="Not super important and always failing. Let's skip it")
|
@unittest.skip(reason="Not super important and always failing. Let's skip it")
|
||||||
@slow
|
@slow
|
||||||
def test_conversion(self):
|
def test_conversion(self):
|
||||||
@@ -442,6 +453,30 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||||||
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
|
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
|
||||||
self.assertListEqual(tokenized_chat, expected_tokens)
|
self.assertListEqual(tokenized_chat, expected_tokens)
|
||||||
|
|
||||||
|
def test_save_fast_load_slow(self):
    """Ensure that we can save a fast tokenizer and load it back as a slow tokenizer."""
    slow_tokenizer = self.tokenizer
    # NOTE(review): whitespace inside this literal may have been collapsed by the diff
    # rendering this was taken from — confirm against the repository.
    text = "a "
    target_encoded = [2, 235250, 139]

    slow = slow_tokenizer.encode(text, add_special_tokens=True)
    self.assertEqual(slow, target_encoded)

    slow_decoded = slow_tokenizer.decode(slow, skip_special_tokens=True)
    self.assertEqual(slow_decoded, text)

    with tempfile.TemporaryDirectory() as dirname:
        # Save fast tokenizer
        self.rust_tokenizer.save_pretrained(dirname)

        # Load slow tokenizer with fast files present in the directory
        slow_tokenizer_from_fast = GemmaTokenizer.from_pretrained(dirname)

        # Everything below stays inside the `with` block so the saved files still
        # exist while the reloaded tokenizer is exercised.
        slow_from_fast = slow_tokenizer_from_fast.encode(text, add_special_tokens=True)
        self.assertEqual(slow_from_fast, target_encoded)

        # Decode the ids produced by the reloaded tokenizer itself (not the ids from the
        # original slow tokenizer) so its own encode -> decode round trip is what is checked.
        slow_from_fast_decoded = slow_tokenizer_from_fast.decode(slow_from_fast, skip_special_tokens=True)
        self.assertEqual(slow_from_fast_decoded, text)
|
||||||
|
|
||||||
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
|
|||||||
Reference in New Issue
Block a user