diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index e552407689..8b34a43e5a 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -104,16 +104,23 @@ class BertTokenizer(PreTrainedTokenizer): def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", - mask_token="[MASK]", **kwargs): + mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs): """Constructs a BertTokenizer. Args: - vocab_file: Path to a one-wordpiece-per-line vocabulary file - do_lower_case: Whether to lower case the input - Only has an effect when do_wordpiece_only=False - do_basic_tokenize: Whether to do basic tokenization before wordpiece. - never_split: List of tokens which will never be split during tokenization. - Only has an effect when do_wordpiece_only=False + **vocab_file**: Path to a one-wordpiece-per-line vocabulary file + **do_lower_case**: (`optional`) boolean (default True) + Whether to lower case the input + Only has an effect when do_basic_tokenize=True + **do_basic_tokenize**: (`optional`) boolean (default True) + Whether to do basic tokenization before wordpiece. + **never_split**: (`optional`) list of string + List of tokens which will never be split during tokenization. + Only has an effect when do_basic_tokenize=True + **tokenize_chinese_chars**: (`optional`) boolean (default True) + Whether to tokenize Chinese characters. + This should likely be desactivated for Japanese: + see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 """ super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, pad_token=pad_token, cls_token=cls_token, @@ -127,8 +134,9 @@ class BertTokenizer(PreTrainedTokenizer): [(ids, tok) for tok, ids in self.vocab.items()]) self.do_basic_tokenize = do_basic_tokenize if do_basic_tokenize: - self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, - never_split=never_split) + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) @property @@ -196,21 +204,36 @@ class BertTokenizer(PreTrainedTokenizer): class BasicTokenizer(object): """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" - def __init__(self, - do_lower_case=True, - never_split=None): - """Constructs a BasicTokenizer. + def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): + """ Constructs a BasicTokenizer. Args: - do_lower_case: Whether to lower case the input. + **do_lower_case**: Whether to lower case the input. + **never_split**: (`optional`) list of str + Kept for backward compatibility purposes. + Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) + List of token not to split. + **tokenize_chinese_chars**: (`optional`) boolean (default True) + Whether to tokenize Chinese characters. + This should likely be desactivated for Japanese: + see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 """ if never_split is None: never_split = [] self.do_lower_case = do_lower_case self.never_split = never_split + self.tokenize_chinese_chars = tokenize_chinese_chars - def tokenize(self, text, never_split=None, tokenize_chinese_chars=True): - """Tokenizes a piece of text.""" + def tokenize(self, text, never_split=None): + """ Basic Tokenization of a piece of text. + Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer. + + Args: + **never_split**: (`optional`) list of str + Kept for backward compatibility purposes. + Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) + List of token not to split. + """ never_split = self.never_split + (never_split if never_split is not None else []) text = self._clean_text(text) # This was added on November 1st, 2018 for the multilingual and Chinese @@ -219,7 +242,7 @@ class BasicTokenizer(object): # and generally don't have any Chinese data in them (there are Chinese # characters in the vocabulary because Wikipedia does have some Chinese # words in the English Wikipedia.). - if tokenize_chinese_chars: + if self.tokenize_chinese_chars: text = self._tokenize_chinese_chars(text) orig_tokens = whitespace_tokenize(text) split_tokens = []