From af23b626c8dcf16827dae08a9ba5ed3b8a8d6d97 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Wed, 11 Sep 2019 18:20:24 +0200 Subject: [PATCH] Max encoding length + corresponding tests --- .../tests/tokenization_tests_commons.py | 29 +++++++++++++++++++ pytorch_transformers/tokenization_utils.py | 22 ++++++++++++-- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 7500741cee..e077333c19 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -211,3 +211,32 @@ class CommonTestCases: # Method is implemented (e.g. not GPT-2) if len(attached_sequences) != 2: assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - sum([len(seq) for seq in sequences]) + + def test_maximum_encoding_length_single_input(self): + tokenizer = self.get_tokenizer() + + seq_0 = "This is a sentence to be encoded." + + sequence = tokenizer.encode(seq_0) + num_added_tokens = tokenizer.num_added_tokens() + total_length = len(sequence) + num_added_tokens + truncated_sequence = tokenizer.encode(seq_0, max_length=total_length - 2, add_special_tokens=True) + + assert len(truncated_sequence) == total_length - 2 + assert truncated_sequence == tokenizer.add_special_tokens_single_sentence(sequence[:-2]) + + def test_maximum_encoding_length_pair_input(self): + tokenizer = self.get_tokenizer() + + seq_0 = "This is a sentence to be encoded." + seq_1 = "This is another sentence to be encoded." + + sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True) + truncated_second_sequence = tokenizer.add_special_tokens_sentences_pair( + tokenizer.encode(seq_0), + tokenizer.encode(seq_1)[:-2] + ) + truncated_sequence = tokenizer.encode(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True) + + assert len(truncated_sequence) == len(sequence) - 2 + assert truncated_sequence == truncated_second_sequence diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index a22f15fa3e..97cf242511 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -693,7 +693,7 @@ class PreTrainedTokenizer(object): def _convert_token_to_id(self, token): raise NotImplementedError - def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, **kwargs): + def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. @@ -706,20 +706,36 @@ class PreTrainedTokenizer(object): to their model. output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence, and 1 for the second. + max_length: if set to a number, will limit the total sequence returned so that it has a maximum length. **kwargs: passed to the `self.tokenize()` method """ if text_pair is None: if add_special_tokens: - return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs))) + sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + if max_length: + sequence_tokens = sequence_tokens[:max_length - self.num_added_tokens()] + return self.add_special_tokens_single_sentence(sequence_tokens) else: - return self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + return ids[:max_length] if max_length != -1 else ids first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if add_special_tokens: + if max_length: + if len(first_sentence_tokens) + self.num_added_tokens(pair=True) >= max_length: + logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.") + else: + if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(pair=True) > max_length: + second_sentence_tokens = second_sentence_tokens[:max_length - len(first_sentence_tokens) - self.num_added_tokens(pair=True)] + return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask) else: + if max_length: + first_sentence_tokens = first_sentence_tokens[:max_length] + second_sentence_tokens = second_sentence_tokens[:max_length] + if output_mask: logger.warning("Can't output mask if you're not joining two sequences.") return first_sentence_tokens, second_sentence_tokens