Stride + tests + small fixes
This commit is contained in:
@@ -217,16 +217,18 @@ class CommonTestCases:
|
|||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
seq_0 = "This is a sentence to be encoded."
|
seq_0 = "This is a sentence to be encoded."
|
||||||
|
stride = 2
|
||||||
|
|
||||||
sequence = tokenizer.encode(seq_0)
|
sequence = tokenizer.encode(seq_0)
|
||||||
num_added_tokens = tokenizer.num_added_tokens()
|
num_added_tokens = tokenizer.num_added_tokens()
|
||||||
total_length = len(sequence) + num_added_tokens
|
total_length = len(sequence) + num_added_tokens
|
||||||
information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True)
|
information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)
|
||||||
|
|
||||||
truncated_sequence = information["sequence"]
|
truncated_sequence = information["sequence"]
|
||||||
overflowing_tokens = information["overflowing_tokens"]
|
overflowing_tokens = information["overflowing_tokens"]
|
||||||
|
|
||||||
assert len(overflowing_tokens) == 2
|
assert len(overflowing_tokens) == 2 + stride
|
||||||
|
assert overflowing_tokens == sequence[-(2 + stride):]
|
||||||
assert len(truncated_sequence) == total_length - 2
|
assert len(truncated_sequence) == total_length - 2
|
||||||
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
|
assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
|
||||||
|
|
||||||
|
|||||||
@@ -76,6 +76,5 @@ class DistilBertTokenizer(BertTokenizer):
|
|||||||
| first sequence | second sequence
|
| first sequence | second sequence
|
||||||
"""
|
"""
|
||||||
sep = [self.sep_token_id]
|
sep = [self.sep_token_id]
|
||||||
cls = [self.cls_token_id]
|
|
||||||
|
|
||||||
return len(self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1)) * [1]
|
return len(self.encode(sequence_0) + sep) * [0] + len(self.encode(sequence_1)) * [1]
|
||||||
|
|||||||
@@ -722,7 +722,7 @@ class PreTrainedTokenizer(object):
|
|||||||
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||||
return first_sentence_tokens + second_sentence_tokens
|
return first_sentence_tokens + second_sentence_tokens
|
||||||
|
|
||||||
def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
|
def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
|
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
|
||||||
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
|
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
|
||||||
@@ -735,6 +735,9 @@ class PreTrainedTokenizer(object):
|
|||||||
output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
|
output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
|
||||||
and 1 for the second.
|
and 1 for the second.
|
||||||
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
|
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
|
||||||
|
If there are overflowing tokens, those will be added to the returned dictionary
|
||||||
|
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
|
||||||
|
from the main sequence returned. The value of this argument defines the number of additional tokens.
|
||||||
**kwargs: passed to the `self.tokenize()` method
|
**kwargs: passed to the `self.tokenize()` method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -745,13 +748,13 @@ class PreTrainedTokenizer(object):
|
|||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
if max_length:
|
if max_length:
|
||||||
information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
|
information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:]
|
||||||
sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
|
sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
|
||||||
sequence = self.add_special_tokens_single_sequence(sequence_tokens)
|
sequence = self.add_special_tokens_single_sequence(sequence_tokens)
|
||||||
else:
|
else:
|
||||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
if max_length:
|
if max_length:
|
||||||
information["overflowing_tokens"] = sequence_tokens[max_length:]
|
information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
|
||||||
sequence_tokens = sequence_tokens[:max_length]
|
sequence_tokens = sequence_tokens[:max_length]
|
||||||
sequence = sequence_tokens
|
sequence = sequence_tokens
|
||||||
|
|
||||||
@@ -788,7 +791,7 @@ class PreTrainedTokenizer(object):
|
|||||||
sequence = first_sentence_tokens + second_sentence_tokens
|
sequence = first_sentence_tokens + second_sentence_tokens
|
||||||
|
|
||||||
if max_length:
|
if max_length:
|
||||||
information["overflowing_tokens"] = sequence[max_length:]
|
information["overflowing_tokens"] = sequence[max_length - stride:]
|
||||||
sequence = sequence[:max_length]
|
sequence = sequence[:max_length]
|
||||||
if output_mask:
|
if output_mask:
|
||||||
information["mask"] = [0] * len(sequence)
|
information["mask"] = [0] * len(sequence)
|
||||||
|
|||||||
Reference in New Issue
Block a user