Fix style

This commit is contained in:
Anthony MOI
2019-12-24 12:43:27 -05:00
parent 951ae99bea
commit 31c56f2e0b
3 changed files with 125 additions and 72 deletions

View File

@@ -20,7 +20,7 @@ import logging
import os
import unicodedata
from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer
from .tokenization_utils import FastPreTrainedTokenizer, PreTrainedTokenizer
logger = logging.getLogger(__name__)
@@ -526,42 +526,64 @@ def _is_punctuation(char):
return True
return False
class BertTokenizerFast(FastPreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
mask_token="[MASK]", tokenize_chinese_chars=True,
max_length=None, pad_to_max_length=False, stride=0,
truncation_strategy='longest_first', add_special_tokens=True, **kwargs):
def __init__(
self,
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
max_length=None,
pad_to_max_length=False,
stride=0,
truncation_strategy="longest_first",
add_special_tokens=True,
**kwargs
):
try:
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors
super(BertTokenizerFast, self).__init__(unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, **kwargs)
self._tokenizer = Tokenizer(models.WordPiece.from_files(
vocab_file,
unk_token=unk_token
))
super(BertTokenizerFast, self).__init__(
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs
)
self._tokenizer = Tokenizer(models.WordPiece.from_files(vocab_file, unk_token=unk_token))
self._update_special_tokens()
self._tokenizer.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new(
do_basic_tokenize=do_basic_tokenize,
do_lower_case=do_lower_case,
tokenize_chinese_chars=tokenize_chinese_chars,
never_split=never_split if never_split is not None else [],
))
self._tokenizer.with_pre_tokenizer(
pre_tokenizers.BertPreTokenizer.new(
do_basic_tokenize=do_basic_tokenize,
do_lower_case=do_lower_case,
tokenize_chinese_chars=tokenize_chinese_chars,
never_split=never_split if never_split is not None else [],
)
)
self._tokenizer.with_decoder(decoders.WordPiece.new())
if add_special_tokens:
self._tokenizer.with_post_processor(processors.BertProcessing.new(
(sep_token, self._tokenizer.token_to_id(sep_token)),
(cls_token, self._tokenizer.token_to_id(cls_token)),
))
self._tokenizer.with_post_processor(
processors.BertProcessing.new(
(sep_token, self._tokenizer.token_to_id(sep_token)),
(cls_token, self._tokenizer.token_to_id(cls_token)),
)
)
if max_length is not None:
self._tokenizer.with_truncation(max_length, stride, truncation_strategy)
self._tokenizer.with_padding(
@@ -569,10 +591,10 @@ class BertTokenizerFast(FastPreTrainedTokenizer):
self.padding_side,
self.pad_token_id,
self.pad_token_type_id,
self.pad_token
self.pad_token,
)
self._decoder = decoders.WordPiece.new()
except (AttributeError, ImportError) as e:
logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`")
raise e
raise e