From e391d4735eb02023dd74fb0f3f751d9acc8b6c7f Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Mon, 2 Sep 2019 16:42:32 -0400 Subject: [PATCH] Tokenizers' encode function can output binary masks --- pytorch_transformers/tokenization_bert.py | 10 ++++++++-- pytorch_transformers/tokenization_roberta.py | 10 ++++++++-- pytorch_transformers/tokenization_utils.py | 8 +++++--- pytorch_transformers/tokenization_xlm.py | 11 +++++++++-- pytorch_transformers/tokenization_xlnet.py | 11 +++++++++-- 5 files changed, 39 insertions(+), 11 deletions(-) diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index b85a4ccf9c..f720d530f8 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -194,14 +194,20 @@ class BertTokenizer(PreTrainedTokenizer): """ return [self.cls_token_id] + token_ids + [self.sep_token_id] - def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): + def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False): """ Adds special tokens to a sequence pair for sequence classification tasks. A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP] """ sep = [self.sep_token_id] cls = [self.cls_token_id] - return cls + token_ids_0 + sep + token_ids_1 + sep + if output_mask: + return ( + cls + token_ids_0 + sep + token_ids_1 + sep, + [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep) + ) + else: + return cls + token_ids_0 + sep + token_ids_1 + sep def save_vocabulary(self, vocab_path): """Save the tokenizer vocabulary to a directory or file.""" diff --git a/pytorch_transformers/tokenization_roberta.py b/pytorch_transformers/tokenization_roberta.py index 67808752d5..db4cbe967f 100644 --- a/pytorch_transformers/tokenization_roberta.py +++ b/pytorch_transformers/tokenization_roberta.py @@ -88,11 +88,17 @@ class RobertaTokenizer(GPT2Tokenizer): """ return [self.cls_token_id] + token_ids + [self.sep_token_id] - def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): + def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False): """ Adds special tokens to a sequence pair for sequence classification tasks. A RoBERTa sequence pair has the following format: A B """ sep = [self.sep_token_id] cls = [self.cls_token_id] - return cls + token_ids_0 + sep + sep + token_ids_1 + sep + if output_mask: + return ( + cls + token_ids_0 + sep + sep + token_ids_1 + sep, + [0] * len(cls + token_ids_0 + sep) + [1] * len(sep + token_ids_1 + sep) + ) + else: + return cls + token_ids_0 + sep + sep + token_ids_1 + sep diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 1e2cd59648..22ab04ac6a 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -663,7 +663,7 @@ class PreTrainedTokenizer(object): def _convert_token_to_id(self, token): raise NotImplementedError - def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs): + def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, **kwargs): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. @@ -674,6 +674,8 @@ class PreTrainedTokenizer(object): text_pair: Optional second sequence to be encoded. add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative to their model. + output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence, + and 1 for the second. **kwargs: passed to the `self.tokenize()` method """ if text_pair is None: @@ -686,7 +688,7 @@ class PreTrainedTokenizer(object): second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if add_special_tokens: - return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens) + return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask) else: return first_sentence_tokens, second_sentence_tokens @@ -694,7 +696,7 @@ class PreTrainedTokenizer(object): logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.") return token_ids - def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): + def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False): logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.") return token_ids_0 + token_ids_1 diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index f7231384b3..167860f15f 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -761,14 +761,21 @@ class XLMTokenizer(PreTrainedTokenizer): """ return [self.cls_token_id] + token_ids + [self.sep_token_id] - def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): + def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False): """ Adds special tokens to a sequence pair for sequence classification tasks. An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP] """ sep = [self.sep_token_id] cls = [self.cls_token_id] - return cls + token_ids_0 + sep + token_ids_1 + sep + + if output_mask: + return ( + cls + token_ids_0 + sep + token_ids_1 + sep, + [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep) + ) + else: + return cls + token_ids_0 + sep + token_ids_1 + sep def save_vocabulary(self, save_directory): """Save the tokenizer vocabulary and merge files to a directory.""" diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index 230095daa9..cbf312bd38 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -190,14 +190,21 @@ class XLNetTokenizer(PreTrainedTokenizer): cls = [self.cls_token_id] return token_ids + sep + cls - def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): + def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False): """ Adds special tokens to a sequence for sequence classification tasks. An XLNet sequence has the following format: X [SEP][CLS] """ + sep = [self.sep_token_id] cls = [self.cls_token_id] - return token_ids_0 + sep + token_ids_1 + sep + cls + if output_mask: + return ( + token_ids_0 + sep + token_ids_1 + sep + cls, + [0] * len(token_ids_0 + sep) + [1] * len(token_ids_1 + sep + cls) + ) + else: + return token_ids_0 + sep + token_ids_1 + sep + cls def save_vocabulary(self, save_directory): """ Save the sentencepiece vocabulary (copy original file) and special tokens file