[tokenizers] Ensure that add_prefix_space is propagated to backend_tokenizer.pre_tokenizer (#35593)
* Ensure that add_prefix_space is propagated to backend_tokenizer.pre_tokenizer
in PreTrainedTokenizerFast, rather than relying on subclasses to take care of this.
* Simplify setting self.add_prefix_space, ensure pre_tok exists
* Wrap in try-except to catch 'Custom PreTokenizer cannot be serialized'
862d1a346a/bindings/python/src/pre_tokenizers.rs (L672) produces the Exception. They're triggered by the roformer tests, as the RoFormerTokenizerFast uses a custom PreTokenizer.
* Propagate add_prefix_space in T5TokenizerFast to superclass
This commit is contained in:
@@ -16,7 +16,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers, processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
@@ -157,14 +157,6 @@ class BartTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
# the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
|
# the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
|
||||||
tokenizer_component = "post_processor"
|
tokenizer_component = "post_processor"
|
||||||
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers, processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
@@ -160,14 +160,6 @@ class BlenderbotTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
tokenizer_component = "post_processor"
|
tokenizer_component = "post_processor"
|
||||||
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
||||||
if tokenizer_component_instance:
|
if tokenizer_component_instance:
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tokenization classes for OpenAI GPT."""
|
"""Tokenization classes for OpenAI GPT."""
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -29,7 +28,6 @@ if TYPE_CHECKING:
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers
|
|
||||||
|
|
||||||
from ...tokenization_utils_base import BatchEncoding
|
from ...tokenization_utils_base import BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
@@ -137,14 +135,6 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
" so that the fast tokenizer works correctly."
|
" so that the fast tokenizer works correctly."
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
|
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
|
||||||
is_split_into_words = kwargs.get("is_split_into_words", False)
|
is_split_into_words = kwargs.get("is_split_into_words", False)
|
||||||
assert self.add_prefix_space or not is_split_into_words, (
|
assert self.add_prefix_space or not is_split_into_words, (
|
||||||
|
|||||||
@@ -14,11 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Fast Tokenization class for model DeBERTa."""
|
"""Fast Tokenization class for model DeBERTa."""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers
|
|
||||||
|
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -132,14 +129,6 @@ class DebertaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
)
|
)
|
||||||
self.add_bos_token = kwargs.pop("add_bos_token", False)
|
self.add_bos_token = kwargs.pop("add_bos_token", False)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mask_token(self) -> str:
|
def mask_token(self) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,11 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tokenization classes for OpenAI GPT."""
|
"""Tokenization classes for OpenAI GPT."""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers
|
|
||||||
|
|
||||||
from ...tokenization_utils_base import BatchEncoding
|
from ...tokenization_utils_base import BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -109,14 +106,6 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
|
|
||||||
self.add_bos_token = kwargs.pop("add_bos_token", False)
|
self.add_bos_token = kwargs.pop("add_bos_token", False)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
|
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
|
||||||
is_split_into_words = kwargs.get("is_split_into_words", False)
|
is_split_into_words = kwargs.get("is_split_into_words", False)
|
||||||
assert self.add_prefix_space or not is_split_into_words, (
|
assert self.add_prefix_space or not is_split_into_words, (
|
||||||
|
|||||||
@@ -14,10 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tokenization classes for GPTNeoX."""
|
"""Tokenization classes for GPTNeoX."""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers, processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -122,14 +121,6 @@ class GPTNeoXTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
self._add_eos_token = add_eos_token
|
self._add_eos_token = add_eos_token
|
||||||
self.update_post_processor()
|
self.update_post_processor()
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def add_eos_token(self):
|
def add_eos_token(self):
|
||||||
return self._add_eos_token
|
return self._add_eos_token
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ and _encode_plus, in which the Rust tokenizer is used.
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers, processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...tokenization_utils_base import (
|
from ...tokenization_utils_base import (
|
||||||
BatchEncoding,
|
BatchEncoding,
|
||||||
@@ -162,14 +162,6 @@ class LayoutLMv3TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
tokenizer_component = "post_processor"
|
tokenizer_component = "post_processor"
|
||||||
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
||||||
if tokenizer_component_instance:
|
if tokenizer_component_instance:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers, processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding, EncodedInput
|
from ...tokenization_utils_base import AddedToken, BatchEncoding, EncodedInput
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
@@ -157,14 +157,6 @@ class LEDTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
# the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
|
# the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
|
||||||
tokenizer_component = "post_processor"
|
tokenizer_component = "post_processor"
|
||||||
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers, processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
@@ -155,14 +155,6 @@ class LongformerTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
tokenizer_component = "post_processor"
|
tokenizer_component = "post_processor"
|
||||||
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
||||||
if tokenizer_component_instance:
|
if tokenizer_component_instance:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import json
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers, processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
|
from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
|
||||||
from ...tokenization_utils_base import (
|
from ...tokenization_utils_base import (
|
||||||
@@ -207,14 +207,6 @@ class MarkupLMTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
|
|
||||||
self.tags_dict = tags_dict
|
self.tags_dict = tags_dict
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
tokenizer_component = "post_processor"
|
tokenizer_component = "post_processor"
|
||||||
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
||||||
if tokenizer_component_instance:
|
if tokenizer_component_instance:
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers, processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
@@ -160,14 +160,6 @@ class MvpTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
# the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
|
# the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
|
||||||
tokenizer_component = "post_processor"
|
tokenizer_component = "post_processor"
|
||||||
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from tokenizers import pre_tokenizers, processors
|
from tokenizers import processors
|
||||||
|
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
@@ -154,14 +154,6 @@ class RobertaTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
tokenizer_component = "post_processor"
|
tokenizer_component = "post_processor"
|
||||||
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
|
||||||
if tokenizer_component_instance:
|
if tokenizer_component_instance:
|
||||||
|
|||||||
@@ -124,6 +124,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
pad_token=pad_token,
|
pad_token=pad_token,
|
||||||
extra_ids=extra_ids,
|
extra_ids=extra_ids,
|
||||||
additional_special_tokens=additional_special_tokens,
|
additional_special_tokens=additional_special_tokens,
|
||||||
|
add_prefix_space=add_prefix_space,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from functools import lru_cache
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tokenizers import AddedToken, pre_tokenizers, processors
|
from tokenizers import AddedToken, processors
|
||||||
|
|
||||||
from ...tokenization_utils_base import BatchEncoding
|
from ...tokenization_utils_base import BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
@@ -128,19 +128,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
|
|
||||||
self.add_bos_token = kwargs.pop("add_bos_token", False)
|
self.add_bos_token = kwargs.pop("add_bos_token", False)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
|
||||||
if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
|
|
||||||
pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
|
|
||||||
pre_tok_state["add_prefix_space"] = add_prefix_space
|
|
||||||
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
|
||||||
|
|
||||||
if normalizer_file is not None:
|
if normalizer_file is not None:
|
||||||
with open(normalizer_file, encoding="utf-8") as vocab_handle:
|
with open(normalizer_file, encoding="utf-8") as vocab_handle:
|
||||||
self.english_spelling_normalizer = json.load(vocab_handle)
|
self.english_spelling_normalizer = json.load(vocab_handle)
|
||||||
else:
|
else:
|
||||||
self.english_spelling_normalizer = None
|
self.english_spelling_normalizer = None
|
||||||
|
|
||||||
self.add_prefix_space = add_prefix_space
|
|
||||||
self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
|
self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
|
||||||
|
|
||||||
self.language = language
|
self.language = language
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
|
fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
|
||||||
from_slow = kwargs.pop("from_slow", False)
|
from_slow = kwargs.pop("from_slow", False)
|
||||||
added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
|
added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
|
||||||
|
self.add_prefix_space = kwargs.get("add_prefix_space", False)
|
||||||
|
|
||||||
if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
|
if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -206,6 +207,18 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
if tokens:
|
if tokens:
|
||||||
self.add_tokens(tokens)
|
self.add_tokens(tokens)
|
||||||
|
|
||||||
|
try:
|
||||||
|
pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
|
||||||
|
if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space:
|
||||||
|
pre_tok_class = getattr(pre_tokenizers_fast, pre_tok_state.pop("type"))
|
||||||
|
pre_tok_state["add_prefix_space"] = self.add_prefix_space
|
||||||
|
self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
|
||||||
|
except Exception:
|
||||||
|
# We'll get an error if there is no pre_tokenizer, or if it's a custom pre_tokenizer that can
|
||||||
|
# not be serialized. In those cases, we just ignore the error as there's no pre_tokenizer
|
||||||
|
# for which we need to update the `add_prefix_space` attribute.
|
||||||
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_fast(self) -> bool:
|
def is_fast(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -4684,3 +4684,15 @@ class TokenizerTesterMixin:
|
|||||||
|
|
||||||
with self.assertRaises(AttributeError, msg="conflicts with the method"):
|
with self.assertRaises(AttributeError, msg="conflicts with the method"):
|
||||||
get_tokenizer_func(get_vocab=True)
|
get_tokenizer_func(get_vocab=True)
|
||||||
|
|
||||||
|
@parameterized.expand([(True,), (False,)])
|
||||||
|
def test_rust_tokenizer_add_prefix_space(self, add_prefix_space):
|
||||||
|
if not self.test_rust_tokenizer:
|
||||||
|
self.skipTest(reason="test_rust_tokenizer is set to False")
|
||||||
|
|
||||||
|
for tokenizer, pretrained_name, _ in self.tokenizers_list:
|
||||||
|
fast_tokenizer = tokenizer.from_pretrained(pretrained_name, add_prefix_space=add_prefix_space)
|
||||||
|
self.assertEqual(fast_tokenizer.add_prefix_space, add_prefix_space)
|
||||||
|
# Only the ByteLevel pre-tokenizer has the `add_prefix_space` attribute, we have to ensure that it's set correctly
|
||||||
|
if hasattr(fast_tokenizer.backend_tokenizer.pre_tokenizer, "add_prefix_space"):
|
||||||
|
self.assertEqual(fast_tokenizer.backend_tokenizer.pre_tokenizer.add_prefix_space, add_prefix_space)
|
||||||
|
|||||||
Reference in New Issue
Block a user