[LlamaTokenizerFast] Adds edge cases for the template processor (#26606)
* make sure eos and bos are properly handled for fast tokenizer * fix code llama as well * nits * fix the conversion script as well * fix failing test
This commit is contained in:
@@ -1192,31 +1192,7 @@ class LlamaConverter(SpmConverter):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def post_processor(self):
|
def post_processor(self):
|
||||||
# 3 possible case :
|
# the processor is defined in the LlamaTokenizerFast class.
|
||||||
# - add_bos and add_eos : '<s>:0 $A:0 </s>:0' and '<s>:0 $A:0 </s>:0 <s>:1 $B:1 </s>:1'
|
|
||||||
# - add_bos: '<s>:0 $A:0' and '<s>:0 $A:0 <s>:1 $B:1'
|
|
||||||
# - add_eos: '$A:0 </s>:0' and '$A:0 </s>:0 $B:1 </s>:1'
|
|
||||||
|
|
||||||
add_bos = self.original_tokenizer.add_bos_token
|
|
||||||
add_eos = self.original_tokenizer.add_eos_token
|
|
||||||
if add_bos or add_eos:
|
|
||||||
bos = self.original_tokenizer.bos_token
|
|
||||||
bos_token_id = self.original_tokenizer.bos_token_id
|
|
||||||
|
|
||||||
eos = self.original_tokenizer.eos_token
|
|
||||||
eos_token_id = self.original_tokenizer.eos_token_id
|
|
||||||
|
|
||||||
single = f"{(bos+':0 ') * add_bos}$A:0{(' '+eos+':0') if add_eos else ''}"
|
|
||||||
pair = f"{single}{(' '+bos+':1') * add_bos} $B:1{(' '+eos+':1') if add_eos else ''}"
|
|
||||||
|
|
||||||
special_tokens = []
|
|
||||||
if add_bos:
|
|
||||||
special_tokens.append((bos, bos_token_id))
|
|
||||||
if add_eos:
|
|
||||||
special_tokens.append((eos, eos_token_id))
|
|
||||||
return processors.TemplateProcessing(single=single, pair=pair, special_tokens=special_tokens)
|
|
||||||
|
|
||||||
else:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -178,12 +178,16 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
bos = self.bos_token
|
bos = self.bos_token
|
||||||
bos_token_id = self.bos_token_id
|
bos_token_id = self.bos_token_id
|
||||||
|
if bos is None and self.add_bos_token:
|
||||||
|
raise ValueError("add_bos_token = True but bos_token = None")
|
||||||
|
|
||||||
eos = self.eos_token
|
eos = self.eos_token
|
||||||
eos_token_id = self.eos_token_id
|
eos_token_id = self.eos_token_id
|
||||||
|
if eos is None and self.add_eos_token:
|
||||||
|
raise ValueError("add_eos_token = True but eos_token = None")
|
||||||
|
|
||||||
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
||||||
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
||||||
|
|
||||||
special_tokens = []
|
special_tokens = []
|
||||||
if self.add_bos_token:
|
if self.add_bos_token:
|
||||||
|
|||||||
@@ -145,12 +145,16 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
bos = self.bos_token
|
bos = self.bos_token
|
||||||
bos_token_id = self.bos_token_id
|
bos_token_id = self.bos_token_id
|
||||||
|
if bos is None and self.add_bos_token:
|
||||||
|
raise ValueError("add_bos_token = True but bos_token = None")
|
||||||
|
|
||||||
eos = self.eos_token
|
eos = self.eos_token
|
||||||
eos_token_id = self.eos_token_id
|
eos_token_id = self.eos_token_id
|
||||||
|
if eos is None and self.add_eos_token:
|
||||||
|
raise ValueError("add_eos_token = True but eos_token = None")
|
||||||
|
|
||||||
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
||||||
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
||||||
|
|
||||||
special_tokens = []
|
special_tokens = []
|
||||||
if self.add_bos_token:
|
if self.add_bos_token:
|
||||||
|
|||||||
@@ -582,6 +582,19 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
# a dummy prefix space is not added by the sp_model as it was de-activated
|
# a dummy prefix space is not added by the sp_model as it was de-activated
|
||||||
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
|
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
|
||||||
|
|
||||||
|
def test_fast_post_processor(self):
|
||||||
|
tokenizer = LlamaTokenizerFast(
|
||||||
|
SAMPLE_VOCAB, eos_token=None, bos_token=None, add_bos_token=False, add_eos_token=False
|
||||||
|
)
|
||||||
|
tokenizer.encode(" Hey ")
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tokenizer = LlamaTokenizerFast(
|
||||||
|
SAMPLE_VOCAB, bos_token=None, eos_token="<s>", add_bos_token=True, add_eos_token=False
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tokenizer = LlamaTokenizerFast(SAMPLE_VOCAB, eos_token=None, add_bos_token=True, add_eos_token=True)
|
||||||
|
|
||||||
@require_jinja
|
@require_jinja
|
||||||
def test_tokenization_for_chat(self):
|
def test_tokenization_for_chat(self):
|
||||||
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
|
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user