[LlamaTokenizerFast] Adds edge cases for the template processor (#26606)
* make sure eos and bos are properly handled for fast tokenizer * fix code llama as well * nits * fix the conversion script as well * fix failing test
This commit is contained in:
@@ -1192,32 +1192,8 @@ class LlamaConverter(SpmConverter):
|
||||
return None
|
||||
|
||||
def post_processor(self):
|
||||
# 3 possible case :
|
||||
# - add_bos and add_eos : '<s>:0 $A:0 </s>:0' and '<s>:0 $A:0 </s>:0 <s>:1 $B:1 </s>:1'
|
||||
# - add_bos: '<s>:0 $A:0' and '<s>:0 $A:0 <s>:1 $B:1'
|
||||
# - add_eos: '$A:0 </s>:0' and '$A:0 </s>:0 $B:1 </s>:1'
|
||||
|
||||
add_bos = self.original_tokenizer.add_bos_token
|
||||
add_eos = self.original_tokenizer.add_eos_token
|
||||
if add_bos or add_eos:
|
||||
bos = self.original_tokenizer.bos_token
|
||||
bos_token_id = self.original_tokenizer.bos_token_id
|
||||
|
||||
eos = self.original_tokenizer.eos_token
|
||||
eos_token_id = self.original_tokenizer.eos_token_id
|
||||
|
||||
single = f"{(bos+':0 ') * add_bos}$A:0{(' '+eos+':0') if add_eos else ''}"
|
||||
pair = f"{single}{(' '+bos+':1') * add_bos} $B:1{(' '+eos+':1') if add_eos else ''}"
|
||||
|
||||
special_tokens = []
|
||||
if add_bos:
|
||||
special_tokens.append((bos, bos_token_id))
|
||||
if add_eos:
|
||||
special_tokens.append((eos, eos_token_id))
|
||||
return processors.TemplateProcessing(single=single, pair=pair, special_tokens=special_tokens)
|
||||
|
||||
else:
|
||||
return None
|
||||
# the processor is defined in the LlamaTokenizerFast class.
|
||||
return None
|
||||
|
||||
|
||||
class MarkupLMConverter(Converter):
|
||||
|
||||
@@ -178,12 +178,16 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
|
||||
"""
|
||||
bos = self.bos_token
|
||||
bos_token_id = self.bos_token_id
|
||||
if bos is None and self.add_bos_token:
|
||||
raise ValueError("add_bos_token = True but bos_token = None")
|
||||
|
||||
eos = self.eos_token
|
||||
eos_token_id = self.eos_token_id
|
||||
if eos is None and self.add_eos_token:
|
||||
raise ValueError("add_eos_token = True but eos_token = None")
|
||||
|
||||
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
||||
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
||||
single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
||||
pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
||||
|
||||
special_tokens = []
|
||||
if self.add_bos_token:
|
||||
|
||||
@@ -145,12 +145,16 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
||||
"""
|
||||
bos = self.bos_token
|
||||
bos_token_id = self.bos_token_id
|
||||
if bos is None and self.add_bos_token:
|
||||
raise ValueError("add_bos_token = True but bos_token = None")
|
||||
|
||||
eos = self.eos_token
|
||||
eos_token_id = self.eos_token_id
|
||||
if eos is None and self.add_eos_token:
|
||||
raise ValueError("add_eos_token = True but eos_token = None")
|
||||
|
||||
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
||||
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
||||
single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
||||
pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
||||
|
||||
special_tokens = []
|
||||
if self.add_bos_token:
|
||||
|
||||
@@ -582,6 +582,19 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
# a dummy prefix space is not added by the sp_model as it was de-activated
|
||||
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
|
||||
|
||||
def test_fast_post_processor(self):
|
||||
tokenizer = LlamaTokenizerFast(
|
||||
SAMPLE_VOCAB, eos_token=None, bos_token=None, add_bos_token=False, add_eos_token=False
|
||||
)
|
||||
tokenizer.encode(" Hey ")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer = LlamaTokenizerFast(
|
||||
SAMPLE_VOCAB, bos_token=None, eos_token="<s>", add_bos_token=True, add_eos_token=False
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer = LlamaTokenizerFast(SAMPLE_VOCAB, eos_token=None, add_bos_token=True, add_eos_token=True)
|
||||
|
||||
@require_jinja
|
||||
def test_tokenization_for_chat(self):
|
||||
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
|
||||
|
||||
Reference in New Issue
Block a user