From baa74326ab4d0443ae22f998e494d6306ec738d8 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 19 Sep 2019 10:42:32 +0200 Subject: [PATCH] Stride + tests + small fixes --- .../tests/tokenization_tests_commons.py | 6 ++++-- pytorch_transformers/tokenization_distilbert.py | 1 - pytorch_transformers/tokenization_utils.py | 11 +++++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index e1c5ccb10a..65a2938378 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -217,16 +217,18 @@ class CommonTestCases: tokenizer = self.get_tokenizer() seq_0 = "This is a sentence to be encoded." + stride = 2 sequence = tokenizer.encode(seq_0) num_added_tokens = tokenizer.num_added_tokens() total_length = len(sequence) + num_added_tokens - information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True) + information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride) truncated_sequence = information["sequence"] overflowing_tokens = information["overflowing_tokens"] - assert len(overflowing_tokens) == 2 + assert len(overflowing_tokens) == 2 + stride + assert overflowing_tokens == sequence[-(2 + stride):] assert len(truncated_sequence) == total_length - 2 assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2]) diff --git a/pytorch_transformers/tokenization_distilbert.py b/pytorch_transformers/tokenization_distilbert.py index cb716594a2..0af782beb1 100644 --- a/pytorch_transformers/tokenization_distilbert.py +++ b/pytorch_transformers/tokenization_distilbert.py @@ -76,6 +76,5 @@ class DistilBertTokenizer(BertTokenizer): | first sequence | second sequence """ sep = [self.sep_token_id] - cls = [self.cls_token_id] return len(self.encode(sequence_0) + sep) * [0] 
+ len(self.encode(sequence_1)) * [1] diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 490b2611e1..c4a9c71917 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -722,7 +722,7 @@ class PreTrainedTokenizer(object): logger.warning("No special tokens were added. The two sequences have been concatenated.") return first_sentence_tokens + second_sentence_tokens - def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs): + def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs): """ Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified. @@ -735,6 +735,9 @@ class PreTrainedTokenizer(object): output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence, and 1 for the second. max_length: if set to a number, will limit the total sequence returned so that it has a maximum length. + If there are overflowing tokens, those will be added to the returned dictionary. + stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens + from the main sequence returned. The value of this argument defines the number of additional tokens. 
**kwargs: passed to the `self.tokenize()` method """ @@ -745,13 +748,13 @@ class PreTrainedTokenizer(object): if add_special_tokens: sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if max_length: - information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:] + information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:] sequence_tokens = sequence_tokens[:max_length - n_added_tokens] sequence = self.add_special_tokens_single_sequence(sequence_tokens) else: sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if max_length: - information["overflowing_tokens"] = sequence_tokens[max_length:] + information["overflowing_tokens"] = sequence_tokens[max_length - stride:] sequence_tokens = sequence_tokens[:max_length] sequence = sequence_tokens @@ -788,7 +791,7 @@ class PreTrainedTokenizer(object): sequence = first_sentence_tokens + second_sentence_tokens if max_length: - information["overflowing_tokens"] = sequence[max_length:] + information["overflowing_tokens"] = sequence[max_length - stride:] sequence = sequence[:max_length] if output_mask: information["mask"] = [0] * len(sequence)