Stride + tests + small fixes

This commit is contained in:
LysandreJik
2019-09-19 10:42:32 +02:00
parent c10c7d59e7
commit baa74326ab
3 changed files with 11 additions and 7 deletions

View File

@@ -722,7 +722,7 @@ class PreTrainedTokenizer(object):
logger.warning("No special tokens were added. The two sequences have been concatenated.")
return first_sentence_tokens + second_sentence_tokens
def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs):
"""
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
@@ -735,6 +735,9 @@ class PreTrainedTokenizer(object):
output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
and 1 for the second.
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defined the number of additional tokens.
**kwargs: passed to the `self.tokenize()` method
"""
@@ -745,13 +748,13 @@ class PreTrainedTokenizer(object):
if add_special_tokens:
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
if max_length:
information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:]
sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
sequence = self.add_special_tokens_single_sequence(sequence_tokens)
else:
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
if max_length:
information["overflowing_tokens"] = sequence_tokens[max_length:]
information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
sequence_tokens = sequence_tokens[:max_length]
sequence = sequence_tokens
@@ -788,7 +791,7 @@ class PreTrainedTokenizer(object):
sequence = first_sentence_tokens + second_sentence_tokens
if max_length:
information["overflowing_tokens"] = sequence[max_length:]
information["overflowing_tokens"] = sequence[max_length - stride:]
sequence = sequence[:max_length]
if output_mask:
information["mask"] = [0] * len(sequence)