From bab3331906484db48be93e0f4768b66098db9494 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Mon, 2 Oct 2023 18:29:27 +0200
Subject: [PATCH] Code-llama-nit (#26300)

* fix encoding when the fill token is None
* add tests and edge cases
* fixup
* Update tests/models/code_llama/test_tokenization_code_llama.py
---
 .../code_llama/tokenization_code_llama_fast.py |  2 +-
 .../code_llama/test_tokenization_code_llama.py | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/code_llama/tokenization_code_llama_fast.py b/src/transformers/models/code_llama/tokenization_code_llama_fast.py
index 91a1896c3c..66a312eb3d 100644
--- a/src/transformers/models/code_llama/tokenization_code_llama_fast.py
+++ b/src/transformers/models/code_llama/tokenization_code_llama_fast.py
@@ -303,7 +303,7 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
     def encode_plus(self, text, text_pair=None, suffix_first=False, add_special_tokens=True, **kwargs):
         # hack to make sure the input is pre-process but outside rust
         text_pair = kwargs.pop("suffix", text_pair)
-        if self.fill_token in text and text_pair is None:
+        if self.fill_token is not None and self.fill_token in text and text_pair is None:
             text, text_pair = text.split(self.fill_token)

         if text_pair is None or len(text_pair) < 1:
diff --git a/tests/models/code_llama/test_tokenization_code_llama.py b/tests/models/code_llama/test_tokenization_code_llama.py
index beab1a5b1b..3df0c552c0 100644
--- a/tests/models/code_llama/test_tokenization_code_llama.py
+++ b/tests/models/code_llama/test_tokenization_code_llama.py
@@ -559,6 +559,24 @@ class LlamaIntegrationTest(unittest.TestCase):
         decoded_tokens = tokenizer.decode(input_ids)
         self.assertEqual(decoded_tokens, " Hello how")

+    def test_fill_token(self):
+        tokenizer = CodeLlamaTokenizerFast.from_pretrained(
+            "codellama/CodeLlama-7b-hf", fill_token=None, prefix_token=None, suffix_token=None, middle_token=None
+        )
+        tokenizer.encode_plus("Hey how are you").input_ids
+        tokenizer.fill_token = ""
+        with self.assertRaises(ValueError):
+            tokenizer.encode("Hey how are you")
+            tokenizer.encode_plus("Hey how are you", "mne too")
+            tokenizer.tokenize("Hey how are you", "mne too")
+
+        tokenizer = CodeLlamaTokenizerFast.from_pretrained(
+            "codellama/CodeLlama-7b-hf", revision="3773f63b4511b9e47a9a7ffc765eed7eb0169486"
+        )
+        tokenizer.encode("Hey how are you")
+        tokenizer.encode_plus("Hey how are you", "mne too")
+        tokenizer.tokenize("Hey how are you", "mne too")
+
     def test_spm_edge_cases(self):
         # the word inform should be split as ['in', 'form']
         tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", legacy=False)