From 7c789c337d9a4f6e9e904d3c1351c28a94183894 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Mon, 30 Sep 2019 10:20:14 -0400 Subject: [PATCH] Always truncate argument in the encode method --- .../tests/tokenization_tests_commons.py | 17 ++++++++ transformers/tokenization_utils.py | 43 +++++++++++++------ 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py index b71ba44436..73a2cb44b3 100644 --- a/transformers/tests/tokenization_tests_commons.py +++ b/transformers/tests/tokenization_tests_commons.py @@ -232,6 +232,23 @@ class CommonTestCases: assert len(truncated_sequence) == total_length - 2 assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2]) + def test_always_truncate(self): + tokenizer = self.get_tokenizer() + + seq_0 = "This is a sentence to be encoded." + length_single_sequence = len(tokenizer.encode(seq_0)) + length = len(tokenizer.encode(seq_0, seq_0, add_special_tokens=True)) + + not_truncated = tokenizer.encode(seq_0, seq_0, add_special_tokens=True, max_length=length_single_sequence) + truncated = tokenizer.encode( + seq_0, seq_0, + max_length=length_single_sequence, + add_special_tokens=True, + always_truncate=True + ) + + assert truncated == not_truncated[:length_single_sequence - length] + def test_maximum_encoding_length_pair_input(self): tokenizer = self.get_tokenizer() diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index db9e9cd72e..fc7fe1df67 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -693,14 +693,15 @@ class PreTrainedTokenizer(object): raise NotImplementedError def encode(self, - text, - text_pair=None, - add_special_tokens=False, - max_length=None, - stride=0, - truncate_first_sequence=True, - return_tensors=None, - **kwargs): + text, + text_pair=None, + add_special_tokens=False, + max_length=None, + stride=0, + truncate_first_sequence=True, + return_tensors=None, + always_truncate=False, + **kwargs): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. @@ -721,6 +722,8 @@ class PreTrainedTokenizer(object): from the main sequence returned. The value of this argument defined the number of additional tokens. truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence will be truncated. + always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the + sequences may be lost in the process. return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers. **kwargs: passed to the `self.tokenize()` method @@ -732,6 +735,7 @@ class PreTrainedTokenizer(object): stride=stride, truncate_first_sequence=truncate_first_sequence, return_tensors=return_tensors, + always_truncate=always_truncate, **kwargs) return encoded_inputs["input_ids"] @@ -744,6 +748,7 @@ class PreTrainedTokenizer(object): stride=0, truncate_first_sequence=True, return_tensors=None, + always_truncate=False, **kwargs): """ Returns a dictionary containing the encoded sequence or sequence pair and additional informations: @@ -764,6 +769,8 @@ class PreTrainedTokenizer(object): from the main sequence returned. The value of this argument defined the number of additional tokens. truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence will be truncated. + always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the + sequences may be lost in the process. return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers. **kwargs: passed to the `self.tokenize()` method @@ -788,11 +795,12 @@ class PreTrainedTokenizer(object): add_special_tokens=add_special_tokens, stride=stride, truncate_first_sequence=truncate_first_sequence, + always_truncate=always_truncate, return_tensors=return_tensors) def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, - truncate_first_sequence=True, return_tensors=None): + truncate_first_sequence=True, always_truncate=False, return_tensors=None): """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It adds special tokens, truncates @@ -812,6 +820,8 @@ class PreTrainedTokenizer(object): truncate_first_sequence: if set to `True` and an optional second list of input ids is provided, alongside a specified `max_length`, will truncate the first sequence if the total size is superior than the specified `max_length`. If set to `False`, will truncate the second sequence instead. + always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the + sequences may be lost in the process. return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers. @@ -826,9 +836,14 @@ class PreTrainedTokenizer(object): if max_length: n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0 if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length: - logger.warning( - "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length." - "This pair of sequences will not be truncated.") + if always_truncate: + logger.warning( + "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. " + "This pair of sequences will be truncated but one of the sequences may not be present in the resulting list of ids.") + else: + logger.warning( + "You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. " + "This pair of sequences will not be truncated.") else: if n_added_tokens + len_ids + len_pair_ids > max_length: if truncate_first_sequence or not pair: @@ -860,6 +875,10 @@ class PreTrainedTokenizer(object): encoded_inputs["input_ids"] = sequence encoded_inputs["token_type_ids"] = token_type_ids + if always_truncate and len(encoded_inputs["input_ids"]) > max_length: + encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length] + encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length] + return encoded_inputs def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):