Tokenization encode/decode class-based sequence handling

This commit is contained in:
LysandreJik
2019-08-09 15:01:38 -04:00
parent fbd746bd06
commit 14e970c271
5 changed files with 47 additions and 19 deletions

View File

@@ -495,7 +495,7 @@ class PreTrainedTokenizer(object):
"""
raise NotImplementedError
def convert_tokens_to_ids(self, tokens, **kwargs):
def convert_tokens_to_ids(self, tokens):
""" Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
(resp. a sequence of ids), using the vocabulary.
"""
@@ -519,31 +519,35 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token):
raise NotImplementedError
def encode(self, *text, cls_token_at_end=False, double_sep_token=False, no_sep_cls_tokens=False):
def encode(self, text, add_special_tokens=False, *sequences):
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
"""
if len(text) == 1:
return self.convert_tokens_to_ids(self.tokenize(text[0]), no_sep_cls_tokens=no_sep_cls_tokens)
if len(sequences) == 0:
if add_special_tokens:
return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text)))
else:
return self.convert_tokens_to_ids(self.tokenize(text))
if len(text) > 2:
if len(sequences) > 1:
logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the "
"initial two.")
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[0])]
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[1])]
sep = [self._convert_token_to_id(self.sep_token)]
cls = [self._convert_token_to_id(self.cls_token)]
n_sep_token = 2 if double_sep_token else 1
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text)]
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(sequences[0])]
tokens = first_sentence_tokens + sep * n_sep_token + second_sentence_tokens + sep
tokens = (tokens + cls) if cls_token_at_end else (cls + tokens)
if add_special_tokens:
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
else:
return first_sentence_tokens, second_sentence_tokens
return tokens
def add_special_tokens_single_sentence(self, token_ids):
raise NotImplementedError
def add_special_tokens_sentences_pair(self, *token_ids):
raise NotImplementedError
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
""" Converts a single index or a sequence of indices (integers) in a token "
@@ -577,8 +581,7 @@ class PreTrainedTokenizer(object):
"""
return ' '.join(self.convert_ids_to_tokens(tokens))
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, cls_token_at_end=False,
double_sep_token=False):
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.