From c10c7d59e7a306609c44d1523346dbbc3fd7121d Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 19 Sep 2019 10:13:10 +0200 Subject: [PATCH] Mask computing in standalone method. Tests. --- .../tests/tokenization_tests_commons.py | 24 +++++++++---------- pytorch_transformers/tokenization_bert.py | 12 ++++++++++ .../tokenization_distilbert.py | 22 ++++++++++------- pytorch_transformers/tokenization_roberta.py | 12 ++++++++++ pytorch_transformers/tokenization_utils.py | 8 +++++-- pytorch_transformers/tokenization_xlm.py | 12 ++++++++++ pytorch_transformers/tokenization_xlnet.py | 14 ++++++++++- 7 files changed, 81 insertions(+), 23 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index e16891a5db..e1c5ccb10a 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -187,18 +187,18 @@ class CommonTestCases: for weights_list_2 in weights_lists_2: self.assertListEqual(weights_list, weights_list_2) - # def test_mask_output(self): - # if sys.version_info <= (3, 0): - # return - # - # tokenizer = self.get_tokenizer() - # - # if tokenizer.add_special_tokens_sequence_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer": - # seq_0 = "Test this method." - # seq_1 = "With these inputs." - # information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True, output_mask=True) - # sequences, mask = information["sequence"], information["mask"] - # assert len(sequences) == len(mask) + def test_mask_output(self): + if sys.version_info <= (3, 0): + return + + tokenizer = self.get_tokenizer() + + if tokenizer.add_special_tokens_sequence_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer": + seq_0 = "Test this method." + seq_1 = "With these inputs." 
+ information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True, output_mask=True) + sequences, mask = information["sequence"], information["mask"] + assert len(sequences) == len(mask) def test_number_of_added_tokens(self): tokenizer = self.get_tokenizer() diff --git a/pytorch_transformers/tokenization_bert.py b/pytorch_transformers/tokenization_bert.py index 0aae2d6a5f..211b8fe93b 100644 --- a/pytorch_transformers/tokenization_bert.py +++ b/pytorch_transformers/tokenization_bert.py @@ -204,6 +204,18 @@ class BertTokenizer(PreTrainedTokenizer): return cls + token_ids_0 + sep + token_ids_1 + sep + def create_mask_from_sequences(self, sequence_0, sequence_1): + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. + A BERT sequence pair mask has the following format: + 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + return len(cls + self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1) + sep) * [1] + def save_vocabulary(self, vocab_path): """Save the tokenizer vocabulary to a directory or file.""" index = 0 diff --git a/pytorch_transformers/tokenization_distilbert.py b/pytorch_transformers/tokenization_distilbert.py index ce2ace35d8..cb716594a2 100644 --- a/pytorch_transformers/tokenization_distilbert.py +++ b/pytorch_transformers/tokenization_distilbert.py @@ -64,12 +64,18 @@ class DistilBertTokenizer(BertTokenizer): def add_special_tokens_single_sequence(self, token_ids): return token_ids - def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1, output_mask=False): + def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1): sep = [self.sep_token_id] - if output_mask: - return ( - token_ids_0 + sep + token_ids_1, - [0] * len(token_ids_0 + sep) + [1] * len(token_ids_1) - ) - else: - return token_ids_0 + sep + token_ids_1 + return token_ids_0 + sep + token_ids_1 + + def 
create_mask_from_sequences(self, sequence_0, sequence_1): + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. + A DistilBERT sequence pair mask has the following format: + 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + return len(self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1)) * [1] diff --git a/pytorch_transformers/tokenization_roberta.py b/pytorch_transformers/tokenization_roberta.py index b39904711a..161c4e6870 100644 --- a/pytorch_transformers/tokenization_roberta.py +++ b/pytorch_transformers/tokenization_roberta.py @@ -96,3 +96,15 @@ class RobertaTokenizer(GPT2Tokenizer): sep = [self.sep_token_id] cls = [self.cls_token_id] return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def create_mask_from_sequences(self, sequence_0, sequence_1): + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 
+ A RoBERTa sequence pair mask has the following format: + 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + return len(cls + self.encode(sequence_0) + sep + sep) * [0] + len(self.encode(sequence_1) + sep) * [1] \ No newline at end of file diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 185ca09576..490b2611e1 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -779,8 +779,8 @@ class PreTrainedTokenizer(object): second_sentence_tokens ) - # if output_mask: - # sequence, information["mask"] = encoded_sequence + if output_mask: + information["mask"] = self.create_mask_from_sequences(text, text_pair) information["sequence"] = sequence else: @@ -797,6 +797,10 @@ class PreTrainedTokenizer(object): return information + def create_mask_from_sequences(self, sequence_0, sequence_1): + logger.warning("This tokenizer does not make use of special tokens.") + return [0] * len(self.encode(sequence_0)) + [1] * len(self.encode(sequence_1)) + def add_special_tokens_single_sequence(self, token_ids): logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.") return token_ids diff --git a/pytorch_transformers/tokenization_xlm.py b/pytorch_transformers/tokenization_xlm.py index 5a29944d10..3b177b6a17 100644 --- a/pytorch_transformers/tokenization_xlm.py +++ b/pytorch_transformers/tokenization_xlm.py @@ -770,6 +770,18 @@ class XLMTokenizer(PreTrainedTokenizer): cls = [self.cls_token_id] return cls + token_ids_0 + sep + token_ids_1 + sep + def create_mask_from_sequences(self, sequence_0, sequence_1): + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. 
+ An XLM sequence pair mask has the following format: + 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + return len(cls + self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1) + sep) * [1] + def save_vocabulary(self, save_directory): """Save the tokenizer vocabulary and merge files to a directory.""" if not os.path.isdir(save_directory): diff --git a/pytorch_transformers/tokenization_xlnet.py b/pytorch_transformers/tokenization_xlnet.py index fecc35b070..1464743759 100644 --- a/pytorch_transformers/tokenization_xlnet.py +++ b/pytorch_transformers/tokenization_xlnet.py @@ -198,9 +198,21 @@ class XLNetTokenizer(PreTrainedTokenizer): sep = [self.sep_token_id] cls = [self.cls_token_id] - cls_segment_ids = [2] return token_ids_0 + sep + token_ids_1 + sep + cls + def create_mask_from_sequences(self, sequence_0, sequence_1): + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. + An XLNet sequence pair mask has the following format: + 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 2 + | first sequence | second sequence | CLS segment ID + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + cls_segment_id = [2] + + return len(self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1) + sep) * [1] + cls_segment_id + def save_vocabulary(self, save_directory): """ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.