From cc412edd42102e6e068b7cbb34d0cf11856d6ae8 Mon Sep 17 00:00:00 2001
From: LysandreJik
Date: Mon, 30 Sep 2019 14:11:41 -0400
Subject: [PATCH] Supports already existing special tokens

---
 transformers/tests/tokenization_tests_commons.py | 12 ++++++++++++
 transformers/tokenization_bert.py                |  6 +++++-
 transformers/tokenization_roberta.py             |  6 +++++-
 transformers/tokenization_utils.py               |  2 +-
 transformers/tokenization_xlm.py                 |  6 +++++-
 transformers/tokenization_xlnet.py               |  6 +++++-
 6 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py
index f44be1f97b..d656e165aa 100644
--- a/transformers/tests/tokenization_tests_commons.py
+++ b/transformers/tests/tokenization_tests_commons.py
@@ -321,4 +321,16 @@ class CommonTestCases:
             filtered_sequence = [x for x in filtered_sequence if x is not None]
             assert encoded_sequence == filtered_sequence
 
+            # Testing with already existing special tokens
+            if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
+                tokenizer.add_special_tokens({'cls_token': '', 'sep_token': ''})
+            encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
+            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
+            sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
+            sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
+            assert len(sequence_ids) == len(encoded_sequence_w_special)
+            print(sequence_ids_orig, sequence_ids)
+            assert sequence_ids_orig == sequence_ids
+
+
 
diff --git a/transformers/tokenization_bert.py b/transformers/tokenization_bert.py
index b05f88b7a8..72720bfa4e 100644
--- a/transformers/tokenization_bert.py
+++ b/transformers/tokenization_bert.py
@@ -204,7 +204,7 @@ class BertTokenizer(PreTrainedTokenizer):
 
         return cls + token_ids_0 + sep + token_ids_1 + sep
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -217,6 +217,10 @@ class BertTokenizer(PreTrainedTokenizer):
         Returns:
             A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
         """
+
+        if special_tokens_present:
+            return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
+
         if token_ids_1:
             return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
         else:
diff --git a/transformers/tokenization_roberta.py b/transformers/tokenization_roberta.py
index c371ae3904..6f19ee699c 100644
--- a/transformers/tokenization_roberta.py
+++ b/transformers/tokenization_roberta.py
@@ -100,7 +100,7 @@ class RobertaTokenizer(GPT2Tokenizer):
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + sep + token_ids_1 + sep
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -113,6 +113,10 @@ class RobertaTokenizer(GPT2Tokenizer):
         Returns:
             A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
         """
+
+        if special_tokens_present:
+            return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
+
         if token_ids_1:
             return [0] + ([1] * len(token_ids_0)) + [0, 0] + ([1] * len(token_ids_1)) + [0]
         else:
diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py
index 5aa7f55730..a0f1382aa2 100644
--- a/transformers/tokenization_utils.py
+++ b/transformers/tokenization_utils.py
@@ -908,7 +908,7 @@ class PreTrainedTokenizer(object):
         logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
         return token_ids_0 + token_ids_1
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         return [1] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
 
     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
diff --git a/transformers/tokenization_xlm.py b/transformers/tokenization_xlm.py
index b6e27ee59e..3f398853dc 100644
--- a/transformers/tokenization_xlm.py
+++ b/transformers/tokenization_xlm.py
@@ -770,7 +770,7 @@ class XLMTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return cls + token_ids_0 + sep + token_ids_1 + sep
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -783,6 +783,10 @@ class XLMTokenizer(PreTrainedTokenizer):
         Returns:
             A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
         """
+
+        if special_tokens_present:
+            return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
+
         if token_ids_1:
             return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
         else:
diff --git a/transformers/tokenization_xlnet.py b/transformers/tokenization_xlnet.py
index c8d59decee..325a32ef6f 100644
--- a/transformers/tokenization_xlnet.py
+++ b/transformers/tokenization_xlnet.py
@@ -200,7 +200,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return token_ids_0 + sep + token_ids_1 + sep + cls
 
-    def get_sequence_ids(self, token_ids_0, token_ids_1=None):
+    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
         """
         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -213,6 +213,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
         Returns:
             A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
         """
+
+        if special_tokens_present:
+            return list(map(lambda x: 0 if x in [self.sep_token_id, self.cls_token_id] else 1, token_ids_0))
+
         if token_ids_1:
             return ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0, 0]
         else: