From 9ad815e412756bb55dbfb57de1cf54503a4b2f23 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Fri, 6 Oct 2023 16:40:54 +0200
Subject: [PATCH] [`LlamaTokenizerFast`] Adds edge cases for the template
processor (#26606)
* make sure eos and bos are properly handled for fast tokenizer
* fix code llama as well
* nits
* fix the conversion script as well
* fix failing test
---
src/transformers/convert_slow_tokenizer.py | 28 ++-----------------
.../tokenization_code_llama_fast.py | 8 ++++--
.../models/llama/tokenization_llama_fast.py | 8 ++++--
tests/models/llama/test_tokenization_llama.py | 13 +++++++++
4 files changed, 27 insertions(+), 30 deletions(-)
diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index cf5e8ca17f..a2195d9cae 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -1192,32 +1192,8 @@ class LlamaConverter(SpmConverter):
return None
def post_processor(self):
- # 3 possible case :
- # - add_bos and add_eos : ':0 $A:0 :0' and ':0 $A:0 :0 :1 $B:1 :1'
- # - add_bos: ':0 $A:0' and ':0 $A:0 :1 $B:1'
- # - add_eos: '$A:0 :0' and '$A:0 :0 $B:1 :1'
-
- add_bos = self.original_tokenizer.add_bos_token
- add_eos = self.original_tokenizer.add_eos_token
- if add_bos or add_eos:
- bos = self.original_tokenizer.bos_token
- bos_token_id = self.original_tokenizer.bos_token_id
-
- eos = self.original_tokenizer.eos_token
- eos_token_id = self.original_tokenizer.eos_token_id
-
- single = f"{(bos+':0 ') * add_bos}$A:0{(' '+eos+':0') if add_eos else ''}"
- pair = f"{single}{(' '+bos+':1') * add_bos} $B:1{(' '+eos+':1') if add_eos else ''}"
-
- special_tokens = []
- if add_bos:
- special_tokens.append((bos, bos_token_id))
- if add_eos:
- special_tokens.append((eos, eos_token_id))
- return processors.TemplateProcessing(single=single, pair=pair, special_tokens=special_tokens)
-
- else:
- return None
+ # the processor is defined in the LlamaTokenizerFast class.
+ return None
class MarkupLMConverter(Converter):
diff --git a/src/transformers/models/code_llama/tokenization_code_llama_fast.py b/src/transformers/models/code_llama/tokenization_code_llama_fast.py
index 7d1e237022..5e8a7945dc 100644
--- a/src/transformers/models/code_llama/tokenization_code_llama_fast.py
+++ b/src/transformers/models/code_llama/tokenization_code_llama_fast.py
@@ -178,12 +178,16 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
"""
bos = self.bos_token
bos_token_id = self.bos_token_id
+ if bos is None and self.add_bos_token:
+ raise ValueError("add_bos_token = True but bos_token = None")
eos = self.eos_token
eos_token_id = self.eos_token_id
+ if eos is None and self.add_eos_token:
+ raise ValueError("add_eos_token = True but eos_token = None")
- single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
- pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
special_tokens = []
if self.add_bos_token:
diff --git a/src/transformers/models/llama/tokenization_llama_fast.py b/src/transformers/models/llama/tokenization_llama_fast.py
index 1d310507f5..6e9cd2aa3b 100644
--- a/src/transformers/models/llama/tokenization_llama_fast.py
+++ b/src/transformers/models/llama/tokenization_llama_fast.py
@@ -145,12 +145,16 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
"""
bos = self.bos_token
bos_token_id = self.bos_token_id
+ if bos is None and self.add_bos_token:
+ raise ValueError("add_bos_token = True but bos_token = None")
eos = self.eos_token
eos_token_id = self.eos_token_id
+ if eos is None and self.add_eos_token:
+ raise ValueError("add_eos_token = True but eos_token = None")
- single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
- pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
special_tokens = []
if self.add_bos_token:
diff --git a/tests/models/llama/test_tokenization_llama.py b/tests/models/llama/test_tokenization_llama.py
index e568414a7b..008ec83c65 100644
--- a/tests/models/llama/test_tokenization_llama.py
+++ b/tests/models/llama/test_tokenization_llama.py
@@ -582,6 +582,19 @@ class LlamaIntegrationTest(unittest.TestCase):
# a dummy prefix space is not added by the sp_model as it was de-activated
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
+ def test_fast_post_processor(self):
+ tokenizer = LlamaTokenizerFast(
+ SAMPLE_VOCAB, eos_token=None, bos_token=None, add_bos_token=False, add_eos_token=False
+ )
+ tokenizer.encode(" Hey ")
+
+ with self.assertRaises(ValueError):
+ tokenizer = LlamaTokenizerFast(
+ SAMPLE_VOCAB, bos_token=None, eos_token="", add_bos_token=True, add_eos_token=False
+ )
+ with self.assertRaises(ValueError):
+ tokenizer = LlamaTokenizerFast(SAMPLE_VOCAB, eos_token=None, add_bos_token=True, add_eos_token=True)
+
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)