Update tokenization_code_llama_fast.py (#26576)
* Update tokenization_code_llama_fast.py * Update test_tokenization_code_llama.py * Update test_tokenization_code_llama.py
This commit is contained in:
@@ -278,7 +278,7 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
|
special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else []
|
||||||
if suffix_first:
|
if suffix_first:
|
||||||
# format as " <PRE> <SUF>{suf} <MID> {pre}"
|
# format as " <PRE> <SUF>{suf} <MID> {pre}"
|
||||||
pair += [self.prefix_token, self.suffix_token, "$A", self.middle_token, "$B"]
|
pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
|
||||||
special_tokens += [
|
special_tokens += [
|
||||||
(self.prefix_token, self.prefix_id),
|
(self.prefix_token, self.prefix_id),
|
||||||
(self.suffix_token, self.suffix_id),
|
(self.suffix_token, self.suffix_id),
|
||||||
|
|||||||
@@ -643,3 +643,15 @@ end
|
|||||||
input_ids = tokenizer.encode(PROMPTS[0])
|
input_ids = tokenizer.encode(PROMPTS[0])
|
||||||
self.assertEqual(input_ids, tokenizer.encode(prefix, suffix=suffix))
|
self.assertEqual(input_ids, tokenizer.encode(prefix, suffix=suffix))
|
||||||
self.assertEqual(tokenizer.encode(prefix, suffix=suffix), tokenizer_fast.encode(prefix, suffix=suffix))
|
self.assertEqual(tokenizer.encode(prefix, suffix=suffix), tokenizer_fast.encode(prefix, suffix=suffix))
|
||||||
|
|
||||||
|
# Adding suffix_first check for infilling tasks
|
||||||
|
suffix_first_formatted_prompt = tokenizer.tokenize(PROMPTS[0], suffix_first=True)
|
||||||
|
self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(PROMPTS[0], suffix_first=True))
|
||||||
|
prefix, suffix = PROMPTS[0].split("<FILL_ME>")
|
||||||
|
self.assertEqual(suffix_first_formatted_prompt, tokenizer.tokenize(prefix, suffix, suffix_first=True))
|
||||||
|
self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(prefix, suffix, suffix_first=True))
|
||||||
|
|
||||||
|
prefix, suffix = PROMPTS[0].split("<FILL_ME>")
|
||||||
|
suffix_first_input_ids = tokenizer.encode(PROMPTS[0], suffix_first=True)
|
||||||
|
self.assertEqual(suffix_first_input_ids, tokenizer.encode(prefix, suffix=suffix, suffix_first=True))
|
||||||
|
self.assertEqual(suffix_first_input_ids, tokenizer_fast.encode(prefix, suffix=suffix, suffix_first=True))
|
||||||
|
|||||||
Reference in New Issue
Block a user