From 951ae99bea3bd8a37397228b6d1f57257a71a6cf Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Tue, 24 Dec 2019 12:24:24 -0500 Subject: [PATCH] BertTokenizerFast --- src/transformers/__init__.py | 2 +- src/transformers/tokenization_bert.py | 53 ++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e7cf22321b..e305c8b15b 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -103,7 +103,7 @@ from .pipelines import ( ) from .tokenization_albert import AlbertTokenizer from .tokenization_auto import AutoTokenizer -from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer +from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer from .tokenization_camembert import CamembertTokenizer from .tokenization_ctrl import CTRLTokenizer diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index e2d3980c47..e2ba03594b 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -20,7 +20,7 @@ import logging import os import unicodedata -from .tokenization_utils import PreTrainedTokenizer +from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer logger = logging.getLogger(__name__) @@ -525,3 +525,54 @@ def _is_punctuation(char): if cat.startswith("P"): return True return False + +class BertTokenizerFast(FastPreTrainedTokenizer): + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, + unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", + mask_token="[MASK]", tokenize_chinese_chars=True, + max_length=None, pad_to_max_length=False, stride=0, + truncation_strategy='longest_first', add_special_tokens=True, **kwargs): + + try: + from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors + super(BertTokenizerFast, self).__init__(unk_token=unk_token, sep_token=sep_token, + pad_token=pad_token, cls_token=cls_token, + mask_token=mask_token, **kwargs) + + self._tokenizer = Tokenizer(models.WordPiece.from_files( + vocab_file, + unk_token=unk_token + )) + self._update_special_tokens() + self._tokenizer.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new( + do_basic_tokenize=do_basic_tokenize, + do_lower_case=do_lower_case, + tokenize_chinese_chars=tokenize_chinese_chars, + never_split=never_split if never_split is not None else [], + )) + self._tokenizer.with_decoder(decoders.WordPiece.new()) + + if add_special_tokens: + self._tokenizer.with_post_processor(processors.BertProcessing.new( + (sep_token, self._tokenizer.token_to_id(sep_token)), + (cls_token, self._tokenizer.token_to_id(cls_token)), + )) + if max_length is not None: + self._tokenizer.with_truncation(max_length, stride, truncation_strategy) + self._tokenizer.with_padding( + max_length if pad_to_max_length else None, + self.padding_side, + self.pad_token_id, + self.pad_token_type_id, + self.pad_token + ) + self._decoder = decoders.WordPiece.new() + + except (AttributeError, ImportError) as e: + logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`") + raise e \ No newline at end of file