Max encoding length + corresponding tests
This commit is contained in:
@@ -211,3 +211,32 @@ class CommonTestCases:
|
||||
# Method is implemented (e.g. not GPT-2)
|
||||
if len(attached_sequences) != 2:
|
||||
assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - sum([len(seq) for seq in sequences])
|
||||
|
||||
def test_maximum_encoding_length_single_input(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
seq_0 = "This is a sentence to be encoded."
|
||||
|
||||
sequence = tokenizer.encode(seq_0)
|
||||
num_added_tokens = tokenizer.num_added_tokens()
|
||||
total_length = len(sequence) + num_added_tokens
|
||||
truncated_sequence = tokenizer.encode(seq_0, max_length=total_length - 2, add_special_tokens=True)
|
||||
|
||||
assert len(truncated_sequence) == total_length - 2
|
||||
assert truncated_sequence == tokenizer.add_special_tokens_single_sentence(sequence[:-2])
|
||||
|
||||
def test_maximum_encoding_length_pair_input(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
seq_0 = "This is a sentence to be encoded."
|
||||
seq_1 = "This is another sentence to be encoded."
|
||||
|
||||
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
|
||||
truncated_second_sequence = tokenizer.add_special_tokens_sentences_pair(
|
||||
tokenizer.encode(seq_0),
|
||||
tokenizer.encode(seq_1)[:-2]
|
||||
)
|
||||
truncated_sequence = tokenizer.encode(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
|
||||
|
||||
assert len(truncated_sequence) == len(sequence) - 2
|
||||
assert truncated_sequence == truncated_second_sequence
|
||||
|
||||
@@ -693,7 +693,7 @@ class PreTrainedTokenizer(object):
|
||||
def _convert_token_to_id(self, token):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, **kwargs):
|
||||
def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
|
||||
"""
|
||||
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
|
||||
@@ -706,20 +706,36 @@ class PreTrainedTokenizer(object):
|
||||
to their model.
|
||||
output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
|
||||
and 1 for the second.
|
||||
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
|
||||
**kwargs: passed to the `self.tokenize()` method
|
||||
"""
|
||||
if text_pair is None:
|
||||
if add_special_tokens:
|
||||
return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
|
||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||
if max_length:
|
||||
sequence_tokens = sequence_tokens[:max_length - self.num_added_tokens()]
|
||||
return self.add_special_tokens_single_sentence(sequence_tokens)
|
||||
else:
|
||||
return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||
ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||
return ids[:max_length] if max_length != -1 else ids
|
||||
|
||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
||||
|
||||
if add_special_tokens:
|
||||
if max_length:
|
||||
if len(first_sentence_tokens) + self.num_added_tokens(pair=True) >= max_length:
|
||||
logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
|
||||
else:
|
||||
if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(pair=True) > max_length:
|
||||
second_sentence_tokens = second_sentence_tokens[:max_length - len(first_sentence_tokens) - self.num_added_tokens(pair=True)]
|
||||
|
||||
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
|
||||
else:
|
||||
if max_length:
|
||||
first_sentence_tokens = first_sentence_tokens[:max_length]
|
||||
second_sentence_tokens = second_sentence_tokens[:max_length]
|
||||
|
||||
if output_mask:
|
||||
logger.warning("Can't output mask if you're not joining two sequences.")
|
||||
return first_sentence_tokens, second_sentence_tokens
|
||||
|
||||
Reference in New Issue
Block a user