Tokenizers' encode function can output binary masks

This commit is contained in:
LysandreJik
2019-09-02 16:42:32 -04:00
parent 0d1dad6d53
commit e391d4735e
5 changed files with 39 additions and 11 deletions

View File

@@ -194,13 +194,19 @@ class BertTokenizer(PreTrainedTokenizer):
""" """
return [self.cls_token_id] + token_ids + [self.sep_token_id] return [self.cls_token_id] + token_ids + [self.sep_token_id]
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
    """
    Adds special tokens to a sequence pair for sequence classification tasks.
    A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]

    Args:
        token_ids_0: list of token ids of the first sequence (A).
        token_ids_1: list of token ids of the second sequence (B).
        output_mask: if set to ``True``, also return a binary mask telling the
            two sequences apart.

    Returns:
        The list of ids with the special tokens inserted. If ``output_mask`` is
        ``True``, a tuple ``(ids, mask)`` instead, where ``mask`` has the same
        length as ``ids`` and holds 0 for ``[CLS] A [SEP]`` and 1 for ``B [SEP]``.
    """
    sep = [self.sep_token_id]
    cls = [self.cls_token_id]
    ids = cls + token_ids_0 + sep + token_ids_1 + sep
    if not output_mask:
        return ids
    # Length of the "[CLS] A [SEP]" segment; everything after it belongs to B.
    first_segment_len = len(cls) + len(token_ids_0) + len(sep)
    mask = [0] * first_segment_len + [1] * (len(ids) - first_segment_len)
    return ids, mask
def save_vocabulary(self, vocab_path): def save_vocabulary(self, vocab_path):

View File

@@ -88,11 +88,17 @@ class RobertaTokenizer(GPT2Tokenizer):
""" """
return [self.cls_token_id] + token_ids + [self.sep_token_id] return [self.cls_token_id] + token_ids + [self.sep_token_id]
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
    """
    Adds special tokens to a sequence pair for sequence classification tasks.
    A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>

    Args:
        token_ids_0: list of token ids of the first sequence (A).
        token_ids_1: list of token ids of the second sequence (B).
        output_mask: if set to ``True``, also return a binary mask telling the
            two sequences apart.

    Returns:
        The list of ids with the special tokens inserted. If ``output_mask`` is
        ``True``, a tuple ``(ids, mask)`` instead, where ``mask`` has the same
        length as ``ids`` and holds 0 for ``<s> A </s>`` and 1 for ``</s> B </s>``
        (the second separator is counted with the second sequence).
    """
    # NOTE(review): confirm how downstream RoBERTa models consume this mask
    # before feeding it to them as segment ids.
    sep = [self.sep_token_id]
    cls = [self.cls_token_id]
    ids = cls + token_ids_0 + sep + sep + token_ids_1 + sep
    if not output_mask:
        return ids
    # Length of the "<s> A </s>" segment; the remainder is "</s> B </s>".
    first_segment_len = len(cls) + len(token_ids_0) + len(sep)
    mask = [0] * first_segment_len + [1] * (len(ids) - first_segment_len)
    return ids, mask

View File

@@ -663,7 +663,7 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
raise NotImplementedError raise NotImplementedError
def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs): def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, **kwargs):
""" """
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@@ -674,6 +674,8 @@ class PreTrainedTokenizer(object):
text_pair: Optional second sequence to be encoded. text_pair: Optional second sequence to be encoded.
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
to their model. to their model.
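output_mask: if set to ``True``, additionally returns a binary mask for the sequence pair, with 0 for the
tokens of the first sequence and 1 for the tokens of the second. Only used when ``text_pair`` is given and ``add_special_tokens`` is ``True``.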
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
""" """
if text_pair is None: if text_pair is None:
@@ -686,7 +688,7 @@ class PreTrainedTokenizer(object):
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
if add_special_tokens: if add_special_tokens:
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens) return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
else: else:
return first_sentence_tokens, second_sentence_tokens return first_sentence_tokens, second_sentence_tokens
@@ -694,7 +696,7 @@ class PreTrainedTokenizer(object):
logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.") logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
return token_ids return token_ids
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
    """
    Fallback for tokenizers that define no special tokens: logs a warning and
    concatenates the two sequences unchanged.

    Args:
        token_ids_0: list of token ids of the first sequence.
        token_ids_1: list of token ids of the second sequence.
        output_mask: if set to ``True``, also return a binary mask telling the
            two sequences apart.

    Returns:
        The concatenated list of ids. If ``output_mask`` is ``True``, a tuple
        ``(ids, mask)`` instead, where ``mask`` holds 0 for every token of the
        first sequence and 1 for every token of the second.
    """
    logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
    ids = token_ids_0 + token_ids_1
    if output_mask:
        # Honor the flag so callers can always unpack ``ids, mask`` regardless
        # of which tokenizer class handled the call.
        return ids, [0] * len(token_ids_0) + [1] * len(token_ids_1)
    return ids

View File

@@ -761,13 +761,20 @@ class XLMTokenizer(PreTrainedTokenizer):
""" """
return [self.cls_token_id] + token_ids + [self.sep_token_id] return [self.cls_token_id] + token_ids + [self.sep_token_id]
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
    """
    Adds special tokens to a sequence pair for sequence classification tasks.
    An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]

    Args:
        token_ids_0: list of token ids of the first sequence (A).
        token_ids_1: list of token ids of the second sequence (B).
        output_mask: if set to ``True``, also return a binary mask telling the
            two sequences apart.

    Returns:
        The list of ids with the special tokens inserted. If ``output_mask`` is
        ``True``, a tuple ``(ids, mask)`` instead, where ``mask`` has the same
        length as ``ids`` and holds 0 for ``[CLS] A [SEP]`` and 1 for ``B [SEP]``.
    """
    sep = [self.sep_token_id]
    cls = [self.cls_token_id]
    ids = cls + token_ids_0 + sep + token_ids_1 + sep
    if not output_mask:
        return ids
    # Length of the "[CLS] A [SEP]" segment; everything after it belongs to B.
    first_segment_len = len(cls) + len(token_ids_0) + len(sep)
    mask = [0] * first_segment_len + [1] * (len(ids) - first_segment_len)
    return ids, mask
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):

View File

@@ -190,13 +190,20 @@ class XLNetTokenizer(PreTrainedTokenizer):
cls = [self.cls_token_id] cls = [self.cls_token_id]
return token_ids + sep + cls return token_ids + sep + cls
def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
    """
    Adds special tokens to a sequence pair for sequence classification tasks.
    An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]

    Args:
        token_ids_0: list of token ids of the first sequence (A).
        token_ids_1: list of token ids of the second sequence (B).
        output_mask: if set to ``True``, also return a binary mask telling the
            two sequences apart.

    Returns:
        The list of ids with the special tokens appended. If ``output_mask`` is
        ``True``, a tuple ``(ids, mask)`` instead, where ``mask`` has the same
        length as ``ids`` and holds 0 for ``A [SEP]`` and 1 for ``B [SEP][CLS]``.
    """
    sep = [self.sep_token_id]
    cls = [self.cls_token_id]
    ids = token_ids_0 + sep + token_ids_1 + sep + cls
    if not output_mask:
        return ids
    # Length of the "A [SEP]" segment; the trailing "B [SEP][CLS]" is marked 1.
    first_segment_len = len(token_ids_0) + len(sep)
    mask = [0] * first_segment_len + [1] * (len(ids) - first_segment_len)
    return ids, mask
def save_vocabulary(self, save_directory): def save_vocabulary(self, save_directory):