From 9ce36e3e4b0b17dd6df05e13e563570677cda39e Mon Sep 17 00:00:00 2001 From: samvelyan Date: Wed, 14 Aug 2019 08:57:09 +0000 Subject: [PATCH] Re-implemented tokenize() iteratively in PreTrainedTokenizer. --- pytorch_transformers/tokenization_utils.py | 42 ++++++++++++++++++---- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 2e75c83bfb..bdeeeb4877 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -428,7 +428,7 @@ class PreTrainedTokenizer(object): Parameters: special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``]. - + Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them). Returns: @@ -472,15 +472,45 @@ class PreTrainedTokenizer(object): Take care of added tokens. """ + def split_on_token(tok, text): + result = [] + split_text = text.split(tok) + for i, sub_text in enumerate(split_text): + sub_text = sub_text.strip() + if i == 0 and not sub_text: + result += [tok] + elif i == len(split_text) - 1: + if sub_text: + result += [sub_text] + else: + pass + else: + if sub_text: + result += [sub_text] + result += [tok] + return result + def split_on_tokens(tok_list, text): if not text: return [] if not tok_list: return self._tokenize(text, **kwargs) - tok = tok_list[0] - split_text = text.split(tok) - return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \ - for sub_text in split_text), [])[:-1] + + tokenized_text = [] + text_list = [text] + for tok in tok_list: + tokenized_text = [] + for sub_text in text_list: + if sub_text not in self.added_tokens_encoder \ + and sub_text not in self.all_special_tokens: + tokenized_text += split_on_token(tok, sub_text) + else: + tokenized_text += [sub_text] + text_list = tokenized_text + + return sum((self._tokenize(token, **kwargs) if token not \ + in self.added_tokens_encoder and token not in self.all_special_tokens \ + else [token] for token in tokenized_text), []) added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens tokenized_text = split_on_tokens(added_tokens, text) @@ -522,7 +552,7 @@ class PreTrainedTokenizer(object): def encode(self, text): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. - + Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``. """ return self.convert_tokens_to_ids(self.tokenize(text))