include changes from llama (#26260)
* include changes from llama * add a test
This commit is contained in:
@@ -293,6 +293,8 @@ class CodeLlamaTokenizer(PreTrainedTokenizer):
|
|||||||
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
||||||
"""
|
"""
|
||||||
tokens = self.sp_model.encode(text, out_type=str)
|
tokens = self.sp_model.encode(text, out_type=str)
|
||||||
|
if not text.startswith((SPIECE_UNDERLINE, " ")):
|
||||||
|
return tokens
|
||||||
# 1. Encode string + prefix ex: "<unk> Hey"
|
# 1. Encode string + prefix ex: "<unk> Hey"
|
||||||
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
||||||
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
||||||
|
|||||||
@@ -559,6 +559,18 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
decoded_tokens = tokenizer.decode(input_ids)
|
decoded_tokens = tokenizer.decode(input_ids)
|
||||||
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
||||||
|
|
||||||
|
def test_spm_edge_cases(self):
|
||||||
|
# the word inform should be split as ['in', 'form']
|
||||||
|
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", legacy=False)
|
||||||
|
tokens = tokenizer.tokenize("[INST] How are you doing?<s>[/INST]")
|
||||||
|
self.assertEqual(
|
||||||
|
tokens, ["▁[", "INST", "]", "▁How", "▁are", "▁you", "▁doing", "?", "<s>", "[", "/", "INST", "]"]
|
||||||
|
)
|
||||||
|
inputs_ids = tokenizer.encode("[INST] How are you doing?<s>[/INST]")
|
||||||
|
self.assertEqual(
|
||||||
|
inputs_ids, [1, 518, 25580, 29962, 1128, 526, 366, 2599, 29973, 1, 29961, 29914, 25580, 29962]
|
||||||
|
)
|
||||||
|
|
||||||
def test_infilling_tokenization(self):
|
def test_infilling_tokenization(self):
|
||||||
PROMPTS = [
|
PROMPTS = [
|
||||||
'''def remove_non_ascii(s: str) -> str:
|
'''def remove_non_ascii(s: str) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user