[LlamaTokenizerFast] nit update post_processor on the fly (#23855)
* Update the processor when changing add_eos and add_bos * fixup * update * add a test * fix failing tests * fixup
This commit is contained in:
@@ -16,6 +16,8 @@ import os
|
|||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import is_sentencepiece_available, logging
|
from ...utils import is_sentencepiece_available, logging
|
||||||
from ...utils.versions import require_version
|
from ...utils.versions import require_version
|
||||||
@@ -84,6 +86,8 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
unk_token="<unk>",
|
unk_token="<unk>",
|
||||||
bos_token="<s>",
|
bos_token="<s>",
|
||||||
eos_token="</s>",
|
eos_token="</s>",
|
||||||
|
add_bos_token=True,
|
||||||
|
add_eos_token=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -95,10 +99,50 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
eos_token=eos_token,
|
eos_token=eos_token,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
self._add_bos_token = add_bos_token
|
||||||
|
self._add_eos_token = add_eos_token
|
||||||
|
self.update_post_processor()
|
||||||
|
|
||||||
self.vocab_file = vocab_file
|
self.vocab_file = vocab_file
|
||||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||||
|
|
||||||
|
def update_post_processor(self):
|
||||||
|
bos = self.bos_token
|
||||||
|
bos_token_id = self.bos_token_id
|
||||||
|
|
||||||
|
eos = self.eos_token
|
||||||
|
eos_token_id = self.eos_token_id
|
||||||
|
|
||||||
|
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
|
||||||
|
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"
|
||||||
|
|
||||||
|
special_tokens = []
|
||||||
|
if self.add_bos_token:
|
||||||
|
special_tokens.append((bos, bos_token_id))
|
||||||
|
if self.add_eos_token:
|
||||||
|
special_tokens.append((eos, eos_token_id))
|
||||||
|
self._tokenizer.post_processor = processors.TemplateProcessing(
|
||||||
|
single=single, pair=pair, special_tokens=special_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def add_eos_token(self):
|
||||||
|
return self._add_eos_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def add_bos_token(self):
|
||||||
|
return self._add_bos_token
|
||||||
|
|
||||||
|
@add_eos_token.setter
|
||||||
|
def add_eos_token(self, value):
|
||||||
|
self._add_eos_token = value
|
||||||
|
self.update_post_processor()
|
||||||
|
|
||||||
|
@add_bos_token.setter
|
||||||
|
def add_bos_token(self, value):
|
||||||
|
self._add_bos_token = value
|
||||||
|
self.update_post_processor()
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
if not self.can_save_slow_tokenizer:
|
if not self.can_save_slow_tokenizer:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -315,6 +315,39 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_fast_special_tokens(self):
|
||||||
|
slow_tokenizer = self.tokenizer
|
||||||
|
fast_tokenizer = self.rust_tokenizer
|
||||||
|
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||||
|
assert slow == [1, 319, 4559, 1243]
|
||||||
|
|
||||||
|
fast_tokenizer.add_eos_token = False
|
||||||
|
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||||
|
assert fast == [1, 319, 4559, 1243]
|
||||||
|
|
||||||
|
fast_tokenizer.add_eos_token = True
|
||||||
|
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||||
|
assert fast == [1, 319, 4559, 1243, 2]
|
||||||
|
|
||||||
|
slow_tokenizer.add_eos_token = True
|
||||||
|
slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||||
|
assert slow == [1, 319, 4559, 1243, 2]
|
||||||
|
|
||||||
|
fast_tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||||
|
"hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
|
||||||
|
)
|
||||||
|
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||||
|
assert fast == [319, 4559, 1243, 2]
|
||||||
|
|
||||||
|
slow_tokenzier = LlamaTokenizer.from_pretrained(
|
||||||
|
"hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
|
||||||
|
)
|
||||||
|
slow = slow_tokenzier.encode("A sample test", add_special_tokens=True)
|
||||||
|
assert slow == [319, 4559, 1243, 2]
|
||||||
|
|
||||||
|
self.tokenizer.add_eos_token = False
|
||||||
|
self.rust_tokenizer.add_eos_token = False
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_conversion(self):
|
def test_conversion(self):
|
||||||
# This is excruciatingly slow since it has to recreate the entire merge
|
# This is excruciatingly slow since it has to recreate the entire merge
|
||||||
|
|||||||
Reference in New Issue
Block a user