From 65aabafe2ff2735f6351b23983f4cec45dbb134c Mon Sep 17 00:00:00 2001 From: Tianqi Liu <31980222+andyl98@users.noreply.github.com> Date: Fri, 6 Oct 2023 01:49:02 -0700 Subject: [PATCH] Update tokenization_code_llama_fast.py (#26576) * Update tokenization_code_llama_fast.py * Update test_tokenization_code_llama.py * Update test_tokenization_code_llama.py --- .../code_llama/tokenization_code_llama_fast.py | 2 +- .../code_llama/test_tokenization_code_llama.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/code_llama/tokenization_code_llama_fast.py b/src/transformers/models/code_llama/tokenization_code_llama_fast.py index 66a312eb3d..7d1e237022 100644 --- a/src/transformers/models/code_llama/tokenization_code_llama_fast.py +++ b/src/transformers/models/code_llama/tokenization_code_llama_fast.py @@ -278,7 +278,7 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast): special_tokens = [(self.bos_token, self.bos_token_id)] if self.add_bos_token and add_special_tokens else [] if suffix_first: # format as "
 {suf}  {pre}"
-            pair += [self.prefix_token, self.suffix_token, "$A", self.middle_token, "$B"]
+            pair += [self.prefix_token, self.suffix_token, "$B", self.middle_token, "$A"]
             special_tokens += [
                 (self.prefix_token, self.prefix_id),
                 (self.suffix_token, self.suffix_id),
diff --git a/tests/models/code_llama/test_tokenization_code_llama.py b/tests/models/code_llama/test_tokenization_code_llama.py
index 3df0c552c0..2673981527 100644
--- a/tests/models/code_llama/test_tokenization_code_llama.py
+++ b/tests/models/code_llama/test_tokenization_code_llama.py
@@ -643,3 +643,15 @@ end
         input_ids = tokenizer.encode(PROMPTS[0])
         self.assertEqual(input_ids, tokenizer.encode(prefix, suffix=suffix))
         self.assertEqual(tokenizer.encode(prefix, suffix=suffix), tokenizer_fast.encode(prefix, suffix=suffix))
+
+        # Adding suffix_first check for infilling tasks
+        suffix_first_formatted_prompt = tokenizer.tokenize(PROMPTS[0], suffix_first=True)
+        self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(PROMPTS[0], suffix_first=True))
+        prefix, suffix = PROMPTS[0].split("")
+        self.assertEqual(suffix_first_formatted_prompt, tokenizer.tokenize(prefix, suffix, suffix_first=True))
+        self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(prefix, suffix, suffix_first=True))
+
+        prefix, suffix = PROMPTS[0].split("")
+        suffix_first_input_ids = tokenizer.encode(PROMPTS[0], suffix_first=True)
+        self.assertEqual(suffix_first_input_ids, tokenizer.encode(prefix, suffix=suffix, suffix_first=True))
+        self.assertEqual(suffix_first_input_ids, tokenizer_fast.encode(prefix, suffix=suffix, suffix_first=True))