[tokenizers] Ensure that add_prefix_space is propagated to backend_tokenizer.pre_tokenizer (#35593)

* Ensure that add_prefix_space is propagated to backend_tokenizer.pre_tokenizer

in PreTrainedTokenizerFast, rather than relying on subclasses to take care of this.

* Simplify setting self.add_prefix_space, ensure pre_tok exists

* Wrap in try-except to catch 'Custom PreTokenizer cannot be serialized'

862d1a346a/bindings/python/src/pre_tokenizers.rs (L672) raises this exception. It is triggered by the RoFormer tests, because RoFormerTokenizerFast uses a custom PreTokenizer.

* Propagate add_prefix_space in T5TokenizerFast to superclass
This commit is contained in:
Tom Aarsen
2025-01-09 17:46:50 +01:00
committed by GitHub
parent 46276f9a7f
commit 32e0db8a69
16 changed files with 36 additions and 122 deletions

View File

@@ -16,7 +16,7 @@
import json import json
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers, processors from tokenizers import processors
from ...tokenization_utils_base import AddedToken, BatchEncoding from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -157,14 +157,6 @@ class BartTokenizerFast(PreTrainedTokenizerFast):
**kwargs, **kwargs,
) )
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
# the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__` # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
tokenizer_component = "post_processor" tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)

View File

@@ -17,7 +17,7 @@
import json import json
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers, processors from tokenizers import processors
from ...tokenization_utils_base import AddedToken, BatchEncoding from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -160,14 +160,6 @@ class BlenderbotTokenizerFast(PreTrainedTokenizerFast):
**kwargs, **kwargs,
) )
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
tokenizer_component = "post_processor" tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
if tokenizer_component_instance: if tokenizer_component_instance:

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
"""Tokenization classes for OpenAI GPT.""" """Tokenization classes for OpenAI GPT."""
import json
import re import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
@@ -29,7 +28,6 @@ if TYPE_CHECKING:
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
from tokenizers import pre_tokenizers
from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -137,14 +135,6 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast):
" so that the fast tokenizer works correctly." " so that the fast tokenizer works correctly."
) )
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
is_split_into_words = kwargs.get("is_split_into_words", False) is_split_into_words = kwargs.get("is_split_into_words", False)
assert self.add_prefix_space or not is_split_into_words, ( assert self.add_prefix_space or not is_split_into_words, (

View File

@@ -14,11 +14,8 @@
# limitations under the License. # limitations under the License.
"""Fast Tokenization class for model DeBERTa.""" """Fast Tokenization class for model DeBERTa."""
import json
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers
from ...tokenization_utils_base import AddedToken, BatchEncoding from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging from ...utils import logging
@@ -132,14 +129,6 @@ class DebertaTokenizerFast(PreTrainedTokenizerFast):
) )
self.add_bos_token = kwargs.pop("add_bos_token", False) self.add_bos_token = kwargs.pop("add_bos_token", False)
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
@property @property
def mask_token(self) -> str: def mask_token(self) -> str:
""" """

View File

@@ -14,11 +14,8 @@
# limitations under the License. # limitations under the License.
"""Tokenization classes for OpenAI GPT.""" """Tokenization classes for OpenAI GPT."""
import json
from typing import Optional, Tuple from typing import Optional, Tuple
from tokenizers import pre_tokenizers
from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging from ...utils import logging
@@ -109,14 +106,6 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
self.add_bos_token = kwargs.pop("add_bos_token", False) self.add_bos_token = kwargs.pop("add_bos_token", False)
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
is_split_into_words = kwargs.get("is_split_into_words", False) is_split_into_words = kwargs.get("is_split_into_words", False)
assert self.add_prefix_space or not is_split_into_words, ( assert self.add_prefix_space or not is_split_into_words, (

View File

@@ -14,10 +14,9 @@
# limitations under the License. # limitations under the License.
"""Tokenization classes for GPTNeoX.""" """Tokenization classes for GPTNeoX."""
import json
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers, processors from tokenizers import processors
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging from ...utils import logging
@@ -122,14 +121,6 @@ class GPTNeoXTokenizerFast(PreTrainedTokenizerFast):
self._add_eos_token = add_eos_token self._add_eos_token = add_eos_token
self.update_post_processor() self.update_post_processor()
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
@property @property
def add_eos_token(self): def add_eos_token(self):
return self._add_eos_token return self._add_eos_token

View File

@@ -20,7 +20,7 @@ and _encode_plus, in which the Rust tokenizer is used.
import json import json
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from tokenizers import pre_tokenizers, processors from tokenizers import processors
from ...tokenization_utils_base import ( from ...tokenization_utils_base import (
BatchEncoding, BatchEncoding,
@@ -162,14 +162,6 @@ class LayoutLMv3TokenizerFast(PreTrainedTokenizerFast):
**kwargs, **kwargs,
) )
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
tokenizer_component = "post_processor" tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
if tokenizer_component_instance: if tokenizer_component_instance:

View File

@@ -17,7 +17,7 @@
import json import json
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from tokenizers import pre_tokenizers, processors from tokenizers import processors
from ...tokenization_utils_base import AddedToken, BatchEncoding, EncodedInput from ...tokenization_utils_base import AddedToken, BatchEncoding, EncodedInput
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -157,14 +157,6 @@ class LEDTokenizerFast(PreTrainedTokenizerFast):
**kwargs, **kwargs,
) )
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
# the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__` # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
tokenizer_component = "post_processor" tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)

View File

@@ -17,7 +17,7 @@
import json import json
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers, processors from tokenizers import processors
from ...tokenization_utils_base import AddedToken, BatchEncoding from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -155,14 +155,6 @@ class LongformerTokenizerFast(PreTrainedTokenizerFast):
**kwargs, **kwargs,
) )
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
tokenizer_component = "post_processor" tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
if tokenizer_component_instance: if tokenizer_component_instance:

View File

@@ -21,7 +21,7 @@ import json
from functools import lru_cache from functools import lru_cache
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from tokenizers import pre_tokenizers, processors from tokenizers import processors
from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
from ...tokenization_utils_base import ( from ...tokenization_utils_base import (
@@ -207,14 +207,6 @@ class MarkupLMTokenizerFast(PreTrainedTokenizerFast):
self.tags_dict = tags_dict self.tags_dict = tags_dict
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
tokenizer_component = "post_processor" tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
if tokenizer_component_instance: if tokenizer_component_instance:

View File

@@ -16,7 +16,7 @@
import json import json
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers, processors from tokenizers import processors
from ...tokenization_utils_base import AddedToken, BatchEncoding from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -160,14 +160,6 @@ class MvpTokenizerFast(PreTrainedTokenizerFast):
**kwargs, **kwargs,
) )
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
# the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__` # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
tokenizer_component = "post_processor" tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)

View File

@@ -17,7 +17,7 @@
import json import json
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from tokenizers import pre_tokenizers, processors from tokenizers import processors
from ...tokenization_utils_base import AddedToken, BatchEncoding from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -154,14 +154,6 @@ class RobertaTokenizerFast(PreTrainedTokenizerFast):
**kwargs, **kwargs,
) )
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
self.add_prefix_space = add_prefix_space
tokenizer_component = "post_processor" tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
if tokenizer_component_instance: if tokenizer_component_instance:

View File

@@ -124,6 +124,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
pad_token=pad_token, pad_token=pad_token,
extra_ids=extra_ids, extra_ids=extra_ids,
additional_special_tokens=additional_special_tokens, additional_special_tokens=additional_special_tokens,
add_prefix_space=add_prefix_space,
**kwargs, **kwargs,
) )

View File

@@ -22,7 +22,7 @@ from functools import lru_cache
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import numpy as np import numpy as np
from tokenizers import AddedToken, pre_tokenizers, processors from tokenizers import AddedToken, processors
from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -128,19 +128,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
self.add_bos_token = kwargs.pop("add_bos_token", False) self.add_bos_token = kwargs.pop("add_bos_token", False)
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
if normalizer_file is not None: if normalizer_file is not None:
with open(normalizer_file, encoding="utf-8") as vocab_handle: with open(normalizer_file, encoding="utf-8") as vocab_handle:
self.english_spelling_normalizer = json.load(vocab_handle) self.english_spelling_normalizer = json.load(vocab_handle)
else: else:
self.english_spelling_normalizer = None self.english_spelling_normalizer = None
self.add_prefix_space = add_prefix_space
self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>") self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
self.language = language self.language = language

View File

@@ -102,6 +102,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
fast_tokenizer_file = kwargs.pop("tokenizer_file", None) fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
from_slow = kwargs.pop("from_slow", False) from_slow = kwargs.pop("from_slow", False)
added_tokens_decoder = kwargs.pop("added_tokens_decoder", {}) added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
self.add_prefix_space = kwargs.get("add_prefix_space", False)
if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None: if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
raise ValueError( raise ValueError(
@@ -206,6 +207,18 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
if tokens: if tokens:
self.add_tokens(tokens) self.add_tokens(tokens)
try:
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space:
pre_tok_class = getattr(pre_tokenizers_fast, pre_tok_state.pop("type"))
pre_tok_state["add_prefix_space"] = self.add_prefix_space
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
except Exception:
# We'll get an error if there is no pre_tokenizer, or if it's a custom pre_tokenizer that can
# not be serialized. In those cases, we just ignore the error as there's no pre_tokenizer
# for which we need to update the `add_prefix_space` attribute.
pass
@property @property
def is_fast(self) -> bool: def is_fast(self) -> bool:
return True return True

View File

@@ -4684,3 +4684,15 @@ class TokenizerTesterMixin:
with self.assertRaises(AttributeError, msg="conflicts with the method"): with self.assertRaises(AttributeError, msg="conflicts with the method"):
get_tokenizer_func(get_vocab=True) get_tokenizer_func(get_vocab=True)
@parameterized.expand([(True,), (False,)])
def test_rust_tokenizer_add_prefix_space(self, add_prefix_space):
    """Check that `add_prefix_space` passed to `from_pretrained` is honoured.

    For every entry in `self.tokenizers_list`, the tokenizer is loaded with the
    given `add_prefix_space` value and the test asserts that:

    * the tokenizer's own `add_prefix_space` attribute equals the requested value;
    * the backend pre-tokenizer's `add_prefix_space` attribute equals it too,
      whenever the pre-tokenizer exposes such an attribute.

    The test is skipped when `self.test_rust_tokenizer` is falsy.
    """
    if not self.test_rust_tokenizer:
        self.skipTest(reason="test_rust_tokenizer is set to False")

    for tokenizer_cls, checkpoint, _ in self.tokenizers_list:
        fast_tokenizer = tokenizer_cls.from_pretrained(checkpoint, add_prefix_space=add_prefix_space)
        self.assertEqual(fast_tokenizer.add_prefix_space, add_prefix_space)

        # Only the ByteLevel pre-tokenizer has the `add_prefix_space` attribute, so the
        # backend check applies only when that attribute is present.
        if not hasattr(fast_tokenizer.backend_tokenizer.pre_tokenizer, "add_prefix_space"):
            continue
        self.assertEqual(fast_tokenizer.backend_tokenizer.pre_tokenizer.add_prefix_space, add_prefix_space)