fix tiktoken convert to pass AddedToken to Tokenizer (#36566)
* pass AddedToken to Tokenizer * ruff * handle dict for special tokens * option: test tokenizer from tiktoken same as fast * ruff * ruff
This commit is contained in:
@@ -1580,7 +1580,9 @@ class TikTokenConverter:
|
||||
self.vocab_file = vocab_file
|
||||
self.pattern = pattern
|
||||
self.add_prefix_space = add_prefix_space
|
||||
self.additional_special_tokens = additional_special_tokens
|
||||
self.additional_special_tokens = (
|
||||
additional_special_tokens.keys() if type(additional_special_tokens) is dict else additional_special_tokens
|
||||
)
|
||||
|
||||
def extract_vocab_merges_from_model(self, tiktoken_url: str):
|
||||
try:
|
||||
@@ -1629,7 +1631,10 @@ class TikTokenConverter:
|
||||
]
|
||||
)
|
||||
tokenizer.decoder = decoders.ByteLevel()
|
||||
tokenizer.add_special_tokens(self.additional_special_tokens)
|
||||
|
||||
tokenizer.add_special_tokens(
|
||||
[AddedToken(token, normalized=False, special=True) for token in self.additional_special_tokens]
|
||||
)
|
||||
|
||||
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
|
||||
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
|
||||
from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_jinja, require_tokenizers
|
||||
from transformers.testing_utils import require_jinja, require_tiktoken, require_tokenizers
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
@@ -299,6 +299,23 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
|
||||
self.assertListEqual(tokenized_chat, expected_tokens)
|
||||
|
||||
@require_tiktoken
|
||||
def test_tokenization_tiktoken(self):
|
||||
from tiktoken import encoding_name_for_model
|
||||
|
||||
from transformers.integrations.tiktoken import convert_tiktoken_to_fast
|
||||
|
||||
encoding = encoding_name_for_model("gpt2")
|
||||
convert_tiktoken_to_fast(encoding, self.tmpdirname)
|
||||
|
||||
tiktoken_fast_tokenizer = GPT2TokenizerFast.from_pretrained(self.tmpdirname)
|
||||
rust_tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
|
||||
sequence = "lower newer"
|
||||
self.assertEqual(
|
||||
rust_tokenizer.decode(rust_tokenizer.encode(sequence)),
|
||||
tiktoken_fast_tokenizer.decode(rust_tokenizer.encode(sequence)),
|
||||
)
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
class OPTTokenizationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user