From 66ea76b8a9e867494993a96a671a94735d99a317 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 19 Sep 2019 13:50:51 +0200 Subject: [PATCH] prepare_for_model and prepare_pair_for_model methods. Added an option to select which sequence will be truncated. --- .../tests/tokenization_tests_commons.py | 15 +++- pytorch_transformers/tokenization_utils.py | 79 +++++++++++++------ 2 files changed, 67 insertions(+), 27 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 65a2938378..5da19bb660 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -237,16 +237,29 @@ class CommonTestCases: seq_0 = "This is a sentence to be encoded." seq_1 = "This is another sentence to be encoded." + stride = 2 + + sequence_0_no_special_tokens = tokenizer.encode(seq_0) + sequence_1_no_special_tokens = tokenizer.encode(seq_1) sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True) truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair( tokenizer.encode(seq_0), tokenizer.encode(seq_1)[:-2] ) - information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True) + + information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True, + stride=stride) + information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, + add_special_tokens=True, stride=stride, + truncate_second_sequence_first=False) truncated_sequence = information["sequence"] overflowing_tokens = information["overflowing_tokens"] + overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"] + assert len(overflowing_tokens) == 2 + stride + assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):] + assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):] assert len(truncated_sequence) == len(sequence) - 2 assert truncated_sequence == truncated_second_sequence diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index c4a9c71917..0b23d00b8a 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -722,7 +722,15 @@ class PreTrainedTokenizer(object): logger.warning("No special tokens were added. The two sequences have been concatenated.") return first_sentence_tokens + second_sentence_tokens - def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs): + def encode_plus(self, + text, + text_pair=None, + add_special_tokens=False, + output_mask=False, + max_length=None, + stride=0, + truncate_second_sequence_first=True, + **kwargs): """ Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. @@ -738,54 +746,40 @@ class PreTrainedTokenizer(object): If there are overflowing tokens, those will be added to the returned dictionary stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens from the main sequence returned. The value of this argument defined the number of additional tokens. + truncate_second_sequence_first: if there is a specified max_length, this flag will choose which sequence + will be truncated. **kwargs: passed to the `self.tokenize()` method """ information = {} if text_pair is None: - n_added_tokens = self.num_added_tokens() + sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if add_special_tokens: - sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) - if max_length: - information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:] - sequence_tokens = sequence_tokens[:max_length - n_added_tokens] - sequence = self.add_special_tokens_single_sequence(sequence_tokens) + information = self.prepare_for_model(sequence_tokens, max_length, stride) else: - sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if max_length: information["overflowing_tokens"] = sequence_tokens[max_length - stride:] sequence_tokens = sequence_tokens[:max_length] - sequence = sequence_tokens + information["sequence"] = sequence_tokens if output_mask: - information["mask"] = [0] * len(sequence) - - information["sequence"] = sequence + information["mask"] = [0] * len(information["sequence"]) else: first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] - f_len, s_len = len(first_sentence_tokens), len(second_sentence_tokens) - n_added_tokens = self.num_added_tokens(pair=True) if add_special_tokens: - if max_length: - if len(first_sentence_tokens) + n_added_tokens >= max_length: - logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.") - else: - if f_len + s_len + self.num_added_tokens(pair=True) > max_length: - information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:] - second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens] - - sequence = self.add_special_tokens_sequence_pair( + information = self.prepare_pair_for_model( first_sentence_tokens, - second_sentence_tokens + second_sentence_tokens, + max_length, + truncate_second_sequence_first, + stride ) if output_mask: information["mask"] = self.create_mask_from_sequences(text, text_pair) - - information["sequence"] = sequence else: logger.warning("No special tokens were added. The two sequences have been concatenated.") sequence = first_sentence_tokens + second_sentence_tokens @@ -800,6 +794,39 @@ class PreTrainedTokenizer(object): return information + def prepare_for_model(self, ids, max_length=None, stride=0): + information = {} + n_added_tokens = self.num_added_tokens() + if max_length: + information["overflowing_tokens"] = ids[max_length - n_added_tokens - stride:] + ids = ids[:max_length - n_added_tokens] + information["sequence"] = self.add_special_tokens_single_sequence(ids) + + return information + + def prepare_pair_for_model(self, ids_0, ids_1, max_length=None, truncate_second_sequence_first=True, stride=0): + f_len, s_len = len(ids_0), len(ids_1) + n_added_tokens = self.num_added_tokens(pair=True) + information = {} + + if max_length: + if len(ids_0) + n_added_tokens >= max_length: + logger.warning( + "The first sequence is longer than the maximum specified length. This sequence will not be truncated.") + else: + if f_len + s_len + self.num_added_tokens(pair=True) > max_length: + if truncate_second_sequence_first: + information["overflowing_tokens"] = ids_1[max_length - f_len - n_added_tokens - stride:] + ids_1 = ids_1[:max_length - f_len - n_added_tokens] + else: + information["overflowing_tokens"] = ids_0[max_length - s_len - n_added_tokens - stride:] + ids_0 = ids_0[:max_length - s_len - n_added_tokens] + + sequence = self.add_special_tokens_sequence_pair(ids_0, ids_1) + information["sequence"] = sequence + + return information + def create_mask_from_sequences(self, sequence_0, sequence_1): logger.warning("This tokenizer does not make use of special tokens.") return [0] * len(self.encode(sequence_0)) + [1] * len(self.encode(sequence_1))