From 14e970c271f8c1f21d46aaadf7e89852d329d3a8 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Fri, 9 Aug 2019 15:01:38 -0400 Subject: [PATCH] Tokenization encode/decode class-based sequence handling --- .../tests/tokenization_tests_commons.py | 5 ++- pytorch_transformers/tokenization_bert.py | 8 +++++ pytorch_transformers/tokenization_utils.py | 35 ++++++++++--------- pytorch_transformers/tokenization_xlm.py | 8 +++++ pytorch_transformers/tokenization_xlnet.py | 10 ++++++ 5 files changed, 47 insertions(+), 19 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index e766a825a0..ebcf6f48d8 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -105,7 +105,7 @@ class CommonTestCases: self.assertEqual(added_toks, len(new_toks)) self.assertEqual(all_size_2, all_size + len(new_toks)) - tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l", no_sep_cls_tokens=True) + tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l") self.assertGreaterEqual(len(tokens), 4) self.assertGreater(tokens[0], tokenizer.vocab_size - 1) self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) @@ -121,8 +121,7 @@ class CommonTestCases: self.assertEqual(added_toks_2, len(new_toks_2)) self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) - tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", - no_sep_cls_tokens=True) + tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l") self.assertGreaterEqual(len(tokens), 6) self.assertGreater(tokens[0], tokenizer.vocab_size - 1) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index 9bf18a97d7..9f4f00a300 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -166,6 +166,14 @@ class BertTokenizer(PreTrainedTokenizer): out_string = ' '.join(tokens).replace(' ##', '').strip() return out_string + def add_special_tokens_single_sentence(self, token_ids): + return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)] + + def add_special_tokens_sentences_pair(self, *token_ids): + sep = [self._convert_token_to_id(self.sep_token)] + cls = [self._convert_token_to_id(self.cls_token)] + return cls + token_ids[0] + sep + token_ids[1] + sep + def save_vocabulary(self, vocab_path): """Save the tokenizer vocabulary to a directory or file.""" index = 0 diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 232ef1c35b..a3581fe582 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -495,7 +495,7 @@ class PreTrainedTokenizer(object): """ raise NotImplementedError - def convert_tokens_to_ids(self, tokens, **kwargs): + def convert_tokens_to_ids(self, tokens): """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id (resp. a sequence of ids), using the vocabulary. """ @@ -519,31 +519,35 @@ class PreTrainedTokenizer(object): def _convert_token_to_id(self, token): raise NotImplementedError - - def encode(self, *text, cls_token_at_end=False, double_sep_token=False, no_sep_cls_tokens=False): + def encode(self, text, add_special_tokens=False, *sequences): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``. """ - if len(text) == 1: - return self.convert_tokens_to_ids(self.tokenize(text[0]), no_sep_cls_tokens=no_sep_cls_tokens) + if len(sequences) == 0: + if add_special_tokens: + return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text))) + else: + return self.convert_tokens_to_ids(self.tokenize(text)) - if len(text) > 2: + if len(sequences) > 1: logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the " "initial two.") - first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[0])] - second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[1])] - sep = [self._convert_token_to_id(self.sep_token)] - cls = [self._convert_token_to_id(self.cls_token)] - n_sep_token = 2 if double_sep_token else 1 + first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)] + second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(sequences[0])] - tokens = first_sentence_tokens + sep * n_sep_token + second_sentence_tokens + sep - tokens = (tokens + cls) if cls_token_at_end else (cls + tokens) + if add_special_tokens: + return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens) + else: + return first_sentence_tokens, second_sentence_tokens - return tokens + def add_special_tokens_single_sentence(self, token_ids): + raise NotImplementedError + def add_special_tokens_sentences_pair(self, *token_ids): + raise NotImplementedError def convert_ids_to_tokens(self, ids, skip_special_tokens=False): """ Converts a single index or a sequence of indices (integers) in a token " @@ -577,8 +581,7 @@ class PreTrainedTokenizer(object): """ return ' '.join(self.convert_ids_to_tokens(tokens)) - def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, cls_token_at_end=False, - double_sep_token=False): + def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True): """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary with options to remove special tokens and clean up tokenization spaces. diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 899f6b884f..b0b8f1d78d 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -214,6 +214,14 @@ class XLMTokenizer(PreTrainedTokenizer): out_string = ''.join(tokens).replace('', ' ').strip() return out_string + def add_special_tokens_single_sentence(self, token_ids): + return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)] + + def add_special_tokens_sentences_pair(self, *token_ids): + sep = [self._convert_token_to_id(self.sep_token)] + cls = [self._convert_token_to_id(self.cls_token)] + return cls + token_ids[0] + sep + token_ids[1] + sep + def save_vocabulary(self, save_directory): """Save the tokenizer vocabulary and merge files to a directory.""" if not os.path.isdir(save_directory): diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index 919ac97bce..42473da860 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -177,6 +177,16 @@ class XLNetTokenizer(PreTrainedTokenizer): out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() return out_string + def add_special_tokens_single_sentence(self, token_ids): + logger.warning("No method was defined for special tokens and single sentence streams in XLNet. " + "Returning token_ids") + return token_ids + + def add_special_tokens_sentences_pair(self, *token_ids): + sep = [self._convert_token_to_id(self.sep_token)] + cls = [self._convert_token_to_id(self.cls_token)] + return token_ids[0] + sep + token_ids[1] + sep + cls + def save_vocabulary(self, save_directory): """ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.