[SPM] Patch spm Llama and T5 (#25656)
* hot fix
* only encode with string prefix if starts with prefix
* styling
* add a new test
* fixup
This commit is contained in:
@@ -220,13 +220,14 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
|||||||
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
||||||
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
||||||
"""
|
"""
|
||||||
if self.legacy:
|
|
||||||
return self.sp_model.encode(text, out_type=str)
|
|
||||||
|
|
||||||
unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
|
|
||||||
text = self.unk_token + text
|
|
||||||
tokens = self.sp_model.encode(text, out_type=str)
|
tokens = self.sp_model.encode(text, out_type=str)
|
||||||
return tokens[unk_token_length:]
|
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
# 1. Encode string + prefix ex: "<unk> Hey"
|
||||||
|
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
||||||
|
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
||||||
|
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
"""Converts a token (str) in an id using the vocab."""
|
"""Converts a token (str) in an id using the vocab."""
|
||||||
|
|||||||
@@ -363,6 +363,10 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
tokens = tokens[1:]
|
tokens = tokens[1:]
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unk_token_length(self):
|
||||||
|
return len(self.sp_model.encode(str(self.unk_token)))
|
||||||
|
|
||||||
def _tokenize(self, text, **kwargs):
|
def _tokenize(self, text, **kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a tokenized string.
|
Returns a tokenized string.
|
||||||
@@ -373,13 +377,14 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
||||||
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
||||||
"""
|
"""
|
||||||
if self.legacy:
|
|
||||||
return self.sp_model.encode(text, out_type=str)
|
|
||||||
|
|
||||||
unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
|
|
||||||
text = self.unk_token + text
|
|
||||||
tokens = self.sp_model.encode(text, out_type=str)
|
tokens = self.sp_model.encode(text, out_type=str)
|
||||||
return tokens[unk_token_length:]
|
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
# 1. Encode string + prefix ex: "<unk> Hey"
|
||||||
|
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
||||||
|
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
||||||
|
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
|
||||||
|
|
||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
"""Converts a token (str) in an id using the vocab."""
|
"""Converts a token (str) in an id using the vocab."""
|
||||||
|
|||||||
@@ -546,6 +546,15 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||||||
decoded_tokens = tokenizer.decode(input_ids)
|
decoded_tokens = tokenizer.decode(input_ids)
|
||||||
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
||||||
|
|
||||||
|
def test_some_edge_cases(self):
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
|
||||||
|
|
||||||
|
sp_tokens = tokenizer.sp_model.encode("<s>>", out_type=str)
|
||||||
|
self.assertEqual(sp_tokens, ["<", "s", ">>"])
|
||||||
|
tokens = tokenizer.tokenize("<s>>")
|
||||||
|
self.assertNotEqual(sp_tokens, tokens)
|
||||||
|
self.assertEqual(tokens, ["<s>", ">"])
|
||||||
|
|
||||||
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
|
|||||||
Reference in New Issue
Block a user