BertTokenizerFast
This commit is contained in:
@@ -103,7 +103,7 @@ from .pipelines import (
|
||||
)
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_auto import AutoTokenizer
|
||||
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
|
||||
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
|
||||
@@ -20,7 +20,7 @@ import logging
|
||||
import os
|
||||
import unicodedata
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -525,3 +525,54 @@ def _is_punctuation(char):
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
|
||||
|
||||
class BertTokenizerFast(FastPreTrainedTokenizer):
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
|
||||
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
|
||||
mask_token="[MASK]", tokenize_chinese_chars=True,
|
||||
max_length=None, pad_to_max_length=False, stride=0,
|
||||
truncation_strategy='longest_first', add_special_tokens=True, **kwargs):
|
||||
|
||||
try:
|
||||
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors
|
||||
super(BertTokenizerFast, self).__init__(unk_token=unk_token, sep_token=sep_token,
|
||||
pad_token=pad_token, cls_token=cls_token,
|
||||
mask_token=mask_token, **kwargs)
|
||||
|
||||
self._tokenizer = Tokenizer(models.WordPiece.from_files(
|
||||
vocab_file,
|
||||
unk_token=unk_token
|
||||
))
|
||||
self._update_special_tokens()
|
||||
self._tokenizer.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new(
|
||||
do_basic_tokenize=do_basic_tokenize,
|
||||
do_lower_case=do_lower_case,
|
||||
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||
never_split=never_split if never_split is not None else [],
|
||||
))
|
||||
self._tokenizer.with_decoder(decoders.WordPiece.new())
|
||||
|
||||
if add_special_tokens:
|
||||
self._tokenizer.with_post_processor(processors.BertProcessing.new(
|
||||
(sep_token, self._tokenizer.token_to_id(sep_token)),
|
||||
(cls_token, self._tokenizer.token_to_id(cls_token)),
|
||||
))
|
||||
if max_length is not None:
|
||||
self._tokenizer.with_truncation(max_length, stride, truncation_strategy)
|
||||
self._tokenizer.with_padding(
|
||||
max_length if pad_to_max_length else None,
|
||||
self.padding_side,
|
||||
self.pad_token_id,
|
||||
self.pad_token_type_id,
|
||||
self.pad_token
|
||||
)
|
||||
self._decoder = decoders.WordPiece.new()
|
||||
|
||||
except (AttributeError, ImportError) as e:
|
||||
logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`")
|
||||
raise e
|
||||
Reference in New Issue
Block a user