[LlamaTokenizerFast] nit update post_processor on the fly (#23855)
* Update the processor when changing add_eos and add_bos * fixup * update * add a test * fix failing tests * fixup
This commit is contained in:
@@ -16,6 +16,8 @@ import os
|
||||
from shutil import copyfile
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from tokenizers import processors
|
||||
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import is_sentencepiece_available, logging
|
||||
from ...utils.versions import require_version
|
||||
@@ -84,6 +86,8 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
||||
unk_token="<unk>",
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
add_bos_token=True,
|
||||
add_eos_token=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -95,10 +99,50 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
||||
eos_token=eos_token,
|
||||
**kwargs,
|
||||
)
|
||||
self._add_bos_token = add_bos_token
|
||||
self._add_eos_token = add_eos_token
|
||||
self.update_post_processor()
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
def update_post_processor(self):
|
||||
bos = self.bos_token
|
||||
bos_token_id = self.bos_token_id
|
||||
|
||||
eos = self.eos_token
|
||||
eos_token_id = self.eos_token_id
|
||||
|
||||
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
|
||||
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"
|
||||
|
||||
special_tokens = []
|
||||
if self.add_bos_token:
|
||||
special_tokens.append((bos, bos_token_id))
|
||||
if self.add_eos_token:
|
||||
special_tokens.append((eos, eos_token_id))
|
||||
self._tokenizer.post_processor = processors.TemplateProcessing(
|
||||
single=single, pair=pair, special_tokens=special_tokens
|
||||
)
|
||||
|
||||
@property
|
||||
def add_eos_token(self):
|
||||
return self._add_eos_token
|
||||
|
||||
@property
|
||||
def add_bos_token(self):
|
||||
return self._add_bos_token
|
||||
|
||||
@add_eos_token.setter
|
||||
def add_eos_token(self, value):
|
||||
self._add_eos_token = value
|
||||
self.update_post_processor()
|
||||
|
||||
@add_bos_token.setter
|
||||
def add_bos_token(self, value):
|
||||
self._add_bos_token = value
|
||||
self.update_post_processor()
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
|
||||
@@ -315,6 +315,39 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_fast_special_tokens(self):
|
||||
slow_tokenizer = self.tokenizer
|
||||
fast_tokenizer = self.rust_tokenizer
|
||||
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert slow == [1, 319, 4559, 1243]
|
||||
|
||||
fast_tokenizer.add_eos_token = False
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [1, 319, 4559, 1243]
|
||||
|
||||
fast_tokenizer.add_eos_token = True
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [1, 319, 4559, 1243, 2]
|
||||
|
||||
slow_tokenizer.add_eos_token = True
|
||||
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert slow == [1, 319, 4559, 1243, 2]
|
||||
|
||||
fast_tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
"hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
|
||||
)
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [319, 4559, 1243, 2]
|
||||
|
||||
slow_tokenzier = LlamaTokenizer.from_pretrained(
|
||||
"hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
|
||||
)
|
||||
slow = slow_tokenzier.encode("A sample test", add_special_tokens=True)
|
||||
assert slow == [319, 4559, 1243, 2]
|
||||
|
||||
self.tokenizer.add_eos_token = False
|
||||
self.rust_tokenizer.add_eos_token = False
|
||||
|
||||
@slow
|
||||
def test_conversion(self):
|
||||
# This is excruciatingly slow since it has to recreate the entire merge
|
||||
|
||||
Reference in New Issue
Block a user