From 51794bf21ee6c9b9a702a3bceeea167e9518880b Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 23 Aug 2023 07:16:43 +0200 Subject: [PATCH] [`SPM`] Patch `spm` Llama and T5 (#25656) * hot fix * only encode with string prefix if starts with prefix * styling * add a new test * fixup --- .../models/llama/tokenization_llama.py | 13 +++++++------ src/transformers/models/t5/tokenization_t5.py | 17 +++++++++++------ tests/models/llama/test_tokenization_llama.py | 9 +++++++++ 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/llama/tokenization_llama.py b/src/transformers/models/llama/tokenization_llama.py index 65ae8e2bd6..808bb0ea52 100644 --- a/src/transformers/models/llama/tokenization_llama.py +++ b/src/transformers/models/llama/tokenization_llama.py @@ -220,13 +220,14 @@ class LlamaTokenizer(PreTrainedTokenizer): `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`. `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`. """ - if self.legacy: - return self.sp_model.encode(text, out_type=str) - - unk_token_length = len(self.sp_model.encode(str(self.unk_token))) - text = self.unk_token + text tokens = self.sp_model.encode(text, out_type=str) - return tokens[unk_token_length:] + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: "<unk> Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. 
Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py index 83fb861b65..9db1a6fa62 100644 --- a/src/transformers/models/t5/tokenization_t5.py +++ b/src/transformers/models/t5/tokenization_t5.py @@ -363,6 +363,10 @@ class T5Tokenizer(PreTrainedTokenizer): tokens = tokens[1:] return tokens + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + def _tokenize(self, text, **kwargs): """ Returns a tokenized string. @@ -373,13 +377,14 @@ class T5Tokenizer(PreTrainedTokenizer): `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`. `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`. """ - if self.legacy: - return self.sp_model.encode(text, out_type=str) - - unk_token_length = len(self.sp_model.encode(str(self.unk_token))) - text = self.unk_token + text tokens = self.sp_model.encode(text, out_type=str) - return tokens[unk_token_length:] + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: "<unk> Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. 
Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens def _convert_token_to_id(self, token): """Converts a token (str) in an id using the vocab.""" diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py index aad6eb7836..25b4f4d8f1 100644 --- a/tests/models/llama/test_tokenization_llama.py +++ b/tests/models/llama/test_tokenization_llama.py @@ -546,6 +546,15 @@ class LlamaIntegrationTest(unittest.TestCase): decoded_tokens = tokenizer.decode(input_ids) self.assertEqual(decoded_tokens, " Hello how") + def test_some_edge_cases(self): + tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False) + + sp_tokens = tokenizer.sp_model.encode("<s>>", out_type=str) + self.assertEqual(sp_tokens, ["<", "s", ">>"]) + tokens = tokenizer.tokenize("<s>>") + self.assertNotEqual(sp_tokens, tokens) + self.assertEqual(tokens, ["<s>", ">"]) + @require_sentencepiece @require_tokenizers