Tokenizers' encode function can output binary masks

This commit is contained in:
LysandreJik
2019-09-02 16:42:32 -04:00
parent 0d1dad6d53
commit e391d4735e
5 changed files with 39 additions and 11 deletions

View File

@@ -663,7 +663,7 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token):
raise NotImplementedError
def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, **kwargs):
"""
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@@ -674,6 +674,8 @@ class PreTrainedTokenizer(object):
text_pair: Optional second sequence to be encoded.
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model.
output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
and 1 for the second.
**kwargs: passed to the `self.tokenize()` method
"""
if text_pair is None:
@@ -686,7 +688,7 @@ class PreTrainedTokenizer(object):
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
if add_special_tokens:
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
else:
return first_sentence_tokens, second_sentence_tokens
@@ -694,7 +696,7 @@ class PreTrainedTokenizer(object):
logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
return token_ids
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
return token_ids_0 + token_ids_1