BertTokenizerFast
This commit is contained in:
@@ -103,7 +103,7 @@ from .pipelines import (
|
|||||||
)
|
)
|
||||||
from .tokenization_albert import AlbertTokenizer
|
from .tokenization_albert import AlbertTokenizer
|
||||||
from .tokenization_auto import AutoTokenizer
|
from .tokenization_auto import AutoTokenizer
|
||||||
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
|
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
|
||||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
||||||
from .tokenization_camembert import CamembertTokenizer
|
from .tokenization_camembert import CamembertTokenizer
|
||||||
from .tokenization_ctrl import CTRLTokenizer
|
from .tokenization_ctrl import CTRLTokenizer
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -525,3 +525,54 @@ def _is_punctuation(char):
|
|||||||
if cat.startswith("P"):
|
if cat.startswith("P"):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
class BertTokenizerFast(FastPreTrainedTokenizer):
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
|
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||||
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
|
||||||
|
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
|
||||||
|
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
|
||||||
|
mask_token="[MASK]", tokenize_chinese_chars=True,
|
||||||
|
max_length=None, pad_to_max_length=False, stride=0,
|
||||||
|
truncation_strategy='longest_first', add_special_tokens=True, **kwargs):
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors
|
||||||
|
super(BertTokenizerFast, self).__init__(unk_token=unk_token, sep_token=sep_token,
|
||||||
|
pad_token=pad_token, cls_token=cls_token,
|
||||||
|
mask_token=mask_token, **kwargs)
|
||||||
|
|
||||||
|
self._tokenizer = Tokenizer(models.WordPiece.from_files(
|
||||||
|
vocab_file,
|
||||||
|
unk_token=unk_token
|
||||||
|
))
|
||||||
|
self._update_special_tokens()
|
||||||
|
self._tokenizer.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new(
|
||||||
|
do_basic_tokenize=do_basic_tokenize,
|
||||||
|
do_lower_case=do_lower_case,
|
||||||
|
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||||
|
never_split=never_split if never_split is not None else [],
|
||||||
|
))
|
||||||
|
self._tokenizer.with_decoder(decoders.WordPiece.new())
|
||||||
|
|
||||||
|
if add_special_tokens:
|
||||||
|
self._tokenizer.with_post_processor(processors.BertProcessing.new(
|
||||||
|
(sep_token, self._tokenizer.token_to_id(sep_token)),
|
||||||
|
(cls_token, self._tokenizer.token_to_id(cls_token)),
|
||||||
|
))
|
||||||
|
if max_length is not None:
|
||||||
|
self._tokenizer.with_truncation(max_length, stride, truncation_strategy)
|
||||||
|
self._tokenizer.with_padding(
|
||||||
|
max_length if pad_to_max_length else None,
|
||||||
|
self.padding_side,
|
||||||
|
self.pad_token_id,
|
||||||
|
self.pad_token_type_id,
|
||||||
|
self.pad_token
|
||||||
|
)
|
||||||
|
self._decoder = decoders.WordPiece.new()
|
||||||
|
|
||||||
|
except (AttributeError, ImportError) as e:
|
||||||
|
logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`")
|
||||||
|
raise e
|
||||||
Reference in New Issue
Block a user