Encode and Decode are back in the superclass. They now handle sentence pairs special tokens.
This commit is contained in:
@@ -495,7 +495,7 @@ class PreTrainedTokenizer(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
def convert_tokens_to_ids(self, tokens, **kwargs):
|
||||
""" Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
|
||||
(resp. a sequence of ids), using the vocabulary.
|
||||
"""
|
||||
@@ -520,12 +520,29 @@ class PreTrainedTokenizer(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def encode(self, text):
|
||||
def encode(self, *text, cls_token_at_end=False, double_sep_token=False, no_sep_cls_tokens=False):
|
||||
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
|
||||
Same doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
|
||||
"""
|
||||
return self.convert_tokens_to_ids(self.tokenize(text))
|
||||
|
||||
if len(text) == 1:
|
||||
return self.convert_tokens_to_ids(self.tokenize(text[0]), no_sep_cls_tokens=no_sep_cls_tokens)
|
||||
|
||||
if len(text) > 2:
|
||||
logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the "
|
||||
"initial two.")
|
||||
|
||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[0])]
|
||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text[1])]
|
||||
sep = [self._convert_token_to_id(self.sep_token)]
|
||||
cls = [self._convert_token_to_id(self.cls_token)]
|
||||
n_sep_token = 2 if double_sep_token else 1
|
||||
|
||||
tokens = first_sentence_tokens + sep * n_sep_token + second_sentence_tokens + sep
|
||||
tokens = (tokens + cls) if cls_token_at_end else (cls + tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
@@ -560,7 +577,8 @@ class PreTrainedTokenizer(object):
|
||||
"""
|
||||
return ' '.join(self.convert_ids_to_tokens(tokens))
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, cls_token_at_end=False,
|
||||
double_sep_token=False):
|
||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
|
||||
with options to remove special tokens and clean up tokenization spaces.
|
||||
|
||||
@@ -568,9 +586,21 @@ class PreTrainedTokenizer(object):
|
||||
"""
|
||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
text = self.convert_tokens_to_string(filtered_tokens)
|
||||
if clean_up_tokenization_spaces:
|
||||
text = self.clean_up_tokenization(text)
|
||||
return text
|
||||
|
||||
if self.sep_token is not None and self.sep_token in text:
|
||||
text = text.replace(self.cls_token, self.sep_token)
|
||||
split_text = list(filter(lambda sentence: len(sentence) > 0, text.split(self.sep_token)))
|
||||
if clean_up_tokenization_spaces:
|
||||
clean_text = [self.clean_up_tokenization(text) for text in split_text]
|
||||
return clean_text
|
||||
else:
|
||||
return split_text
|
||||
else:
|
||||
if clean_up_tokenization_spaces:
|
||||
clean_text = self.clean_up_tokenization(text)
|
||||
return clean_text
|
||||
else:
|
||||
return text
|
||||
|
||||
@property
|
||||
def special_tokens_map(self):
|
||||
@@ -602,7 +632,7 @@ class PreTrainedTokenizer(object):
|
||||
class attributes (cls_token, unk_token...).
|
||||
"""
|
||||
all_toks = self.all_special_tokens
|
||||
all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
|
||||
all_ids = list(self._convert_token_to_id(t) for t in all_toks)
|
||||
return all_ids
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user