always_truncate by default
This commit is contained in:
@@ -232,23 +232,6 @@ class CommonTestCases:
|
|||||||
assert len(truncated_sequence) == total_length - 2
|
assert len(truncated_sequence) == total_length - 2
|
||||||
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
|
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
|
||||||
|
|
||||||
def test_always_truncate(self):
|
|
||||||
tokenizer = self.get_tokenizer()
|
|
||||||
|
|
||||||
seq_0 = "This is a sentence to be encoded."
|
|
||||||
length_single_sequence = len(tokenizer.encode(seq_0))
|
|
||||||
length = len(tokenizer.encode(seq_0, seq_0, add_special_tokens=True))
|
|
||||||
|
|
||||||
not_truncated = tokenizer.encode(seq_0, seq_0, add_special_tokens=True, max_length=length_single_sequence)
|
|
||||||
truncated = tokenizer.encode(
|
|
||||||
seq_0, seq_0,
|
|
||||||
max_length=length_single_sequence,
|
|
||||||
add_special_tokens=True,
|
|
||||||
always_truncate=True
|
|
||||||
)
|
|
||||||
|
|
||||||
assert truncated == not_truncated[:length_single_sequence - length]
|
|
||||||
|
|
||||||
def test_maximum_encoding_length_pair_input(self):
|
def test_maximum_encoding_length_pair_input(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
@@ -329,7 +312,6 @@ class CommonTestCases:
|
|||||||
sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
|
sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
|
||||||
sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
|
sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
|
||||||
assert len(sequence_ids) == len(encoded_sequence_w_special)
|
assert len(sequence_ids) == len(encoded_sequence_w_special)
|
||||||
print(sequence_ids_orig, sequence_ids)
|
|
||||||
assert sequence_ids_orig == sequence_ids
|
assert sequence_ids_orig == sequence_ids
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -700,7 +700,6 @@ class PreTrainedTokenizer(object):
|
|||||||
stride=0,
|
stride=0,
|
||||||
truncate_first_sequence=True,
|
truncate_first_sequence=True,
|
||||||
return_tensors=None,
|
return_tensors=None,
|
||||||
always_truncate=False,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||||
@@ -722,8 +721,6 @@ class PreTrainedTokenizer(object):
|
|||||||
from the main sequence returned. The value of this argument defined the number of additional tokens.
|
from the main sequence returned. The value of this argument defined the number of additional tokens.
|
||||||
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
|
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
|
||||||
will be truncated.
|
will be truncated.
|
||||||
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
|
|
||||||
sequences may be lost in the process.
|
|
||||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||||
or PyTorch torch.Tensor instead of a list of python integers.
|
or PyTorch torch.Tensor instead of a list of python integers.
|
||||||
**kwargs: passed to the `self.tokenize()` method
|
**kwargs: passed to the `self.tokenize()` method
|
||||||
@@ -735,7 +732,6 @@ class PreTrainedTokenizer(object):
|
|||||||
stride=stride,
|
stride=stride,
|
||||||
truncate_first_sequence=truncate_first_sequence,
|
truncate_first_sequence=truncate_first_sequence,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
always_truncate=always_truncate,
|
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
return encoded_inputs["input_ids"]
|
return encoded_inputs["input_ids"]
|
||||||
@@ -748,7 +744,6 @@ class PreTrainedTokenizer(object):
|
|||||||
stride=0,
|
stride=0,
|
||||||
truncate_first_sequence=True,
|
truncate_first_sequence=True,
|
||||||
return_tensors=None,
|
return_tensors=None,
|
||||||
always_truncate=False,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
|
Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
|
||||||
@@ -769,8 +764,6 @@ class PreTrainedTokenizer(object):
|
|||||||
from the main sequence returned. The value of this argument defined the number of additional tokens.
|
from the main sequence returned. The value of this argument defined the number of additional tokens.
|
||||||
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
|
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
|
||||||
will be truncated.
|
will be truncated.
|
||||||
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
|
|
||||||
sequences may be lost in the process.
|
|
||||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||||
or PyTorch torch.Tensor instead of a list of python integers.
|
or PyTorch torch.Tensor instead of a list of python integers.
|
||||||
**kwargs: passed to the `self.tokenize()` method
|
**kwargs: passed to the `self.tokenize()` method
|
||||||
@@ -795,12 +788,10 @@ class PreTrainedTokenizer(object):
|
|||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
truncate_first_sequence=truncate_first_sequence,
|
truncate_first_sequence=truncate_first_sequence,
|
||||||
always_truncate=always_truncate,
|
|
||||||
return_tensors=return_tensors)
|
return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
|
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
|
||||||
truncate_first_sequence=True, always_truncate=False, return_tensors=None):
|
truncate_first_sequence=True, return_tensors=None):
|
||||||
"""
|
"""
|
||||||
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
||||||
It adds special tokens, truncates
|
It adds special tokens, truncates
|
||||||
@@ -820,8 +811,6 @@ class PreTrainedTokenizer(object):
|
|||||||
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
|
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
|
||||||
alongside a specified `max_length`, will truncate the first sequence if the total size is superior
|
alongside a specified `max_length`, will truncate the first sequence if the total size is superior
|
||||||
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
|
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
|
||||||
always_truncate: if set to True, will always truncate the sequences when overflowing, even if one of the
|
|
||||||
sequences may be lost in the process.
|
|
||||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||||
or PyTorch torch.Tensor instead of a list of python integers.
|
or PyTorch torch.Tensor instead of a list of python integers.
|
||||||
|
|
||||||
@@ -850,14 +839,9 @@ class PreTrainedTokenizer(object):
|
|||||||
if max_length:
|
if max_length:
|
||||||
n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
|
n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
|
||||||
if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
|
if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
|
||||||
if always_truncate:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
|
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
|
||||||
"This pair of sequences will be truncated but one of the sequences may not be present in the resulting list of ids.")
|
"This pair of sequences will be truncated with no regard to the special tokens")
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
|
|
||||||
"This pair of sequences will not be truncated.")
|
|
||||||
else:
|
else:
|
||||||
if n_added_tokens + len_ids + len_pair_ids > max_length:
|
if n_added_tokens + len_ids + len_pair_ids > max_length:
|
||||||
if truncate_first_sequence or not pair:
|
if truncate_first_sequence or not pair:
|
||||||
@@ -890,7 +874,7 @@ class PreTrainedTokenizer(object):
|
|||||||
encoded_inputs["input_ids"] = sequence
|
encoded_inputs["input_ids"] = sequence
|
||||||
encoded_inputs["token_type_ids"] = token_type_ids
|
encoded_inputs["token_type_ids"] = token_type_ids
|
||||||
|
|
||||||
if always_truncate and len(encoded_inputs["input_ids"]) > max_length:
|
if max_length and len(encoded_inputs["input_ids"]) > max_length:
|
||||||
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
|
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
|
||||||
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
|
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user