Max encoding length + corresponding tests

This commit is contained in:
LysandreJik
2019-09-11 18:20:24 +02:00
parent c4d4f3ec8c
commit af23b626c8
2 changed files with 48 additions and 3 deletions

View File

@@ -693,7 +693,7 @@ class PreTrainedTokenizer(object):
def _convert_token_to_id(self, token):
raise NotImplementedError
def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, **kwargs):
def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
"""
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@@ -706,20 +706,36 @@ class PreTrainedTokenizer(object):
to their model.
output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
and 1 for the second.
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
**kwargs: passed to the `self.tokenize()` method
"""
if text_pair is None:
if add_special_tokens:
return self.add_special_tokens_single_sentence(self.convert_tokens_to_ids(self.tokenize(text, **kwargs)))
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
if max_length:
sequence_tokens = sequence_tokens[:max_length - self.num_added_tokens()]
return self.add_special_tokens_single_sentence(sequence_tokens)
else:
return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
return ids[:max_length] if max_length != -1 else ids
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
if add_special_tokens:
if max_length:
if len(first_sentence_tokens) + self.num_added_tokens(pair=True) >= max_length:
logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
else:
if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(pair=True) > max_length:
second_sentence_tokens = second_sentence_tokens[:max_length - len(first_sentence_tokens) - self.num_added_tokens(pair=True)]
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
else:
if max_length:
first_sentence_tokens = first_sentence_tokens[:max_length]
second_sentence_tokens = second_sentence_tokens[:max_length]
if output_mask:
logger.warning("Can't output mask if you're not joining two sequences.")
return first_sentence_tokens, second_sentence_tokens