prepare_for_model and prepare_pair_for_model methods. Added an option to select which sequence will be truncated.
This commit is contained in:
@@ -237,16 +237,29 @@ class CommonTestCases:
|
||||
|
||||
seq_0 = "This is a sentence to be encoded."
|
||||
seq_1 = "This is another sentence to be encoded."
|
||||
stride = 2
|
||||
|
||||
sequence_0_no_special_tokens = tokenizer.encode(seq_0)
|
||||
sequence_1_no_special_tokens = tokenizer.encode(seq_1)
|
||||
|
||||
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
|
||||
truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
|
||||
tokenizer.encode(seq_0),
|
||||
tokenizer.encode(seq_1)[:-2]
|
||||
)
|
||||
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
|
||||
|
||||
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
|
||||
stride=stride)
|
||||
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
|
||||
add_special_tokens=True, stride=stride,
|
||||
truncate_second_sequence_first=False)
|
||||
|
||||
truncated_sequence = information["sequence"]
|
||||
overflowing_tokens = information["overflowing_tokens"]
|
||||
overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]
|
||||
|
||||
assert len(overflowing_tokens) == 2 + stride
|
||||
assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):]
|
||||
assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):]
|
||||
assert len(truncated_sequence) == len(sequence) - 2
|
||||
assert truncated_sequence == truncated_second_sequence
|
||||
|
||||
@@ -722,7 +722,15 @@ class PreTrainedTokenizer(object):
|
||||
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||
return first_sentence_tokens + second_sentence_tokens
|
||||
|
||||
def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs):
|
||||
def encode_plus(self,
|
||||
text,
|
||||
text_pair=None,
|
||||
add_special_tokens=False,
|
||||
output_mask=False,
|
||||
max_length=None,
|
||||
stride=0,
|
||||
truncate_second_sequence_first=True,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
|
||||
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
|
||||
@@ -738,54 +746,40 @@ class PreTrainedTokenizer(object):
|
||||
If there are overflowing tokens, those will be added to the returned dictionary
|
||||
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
|
||||
from the main sequence returned. The value of this argument defined the number of additional tokens.
|
||||
truncate_second_sequence_first: if there is a specified max_length, this flag will choose which sequence
|
||||
will be truncated.
|
||||
**kwargs: passed to the `self.tokenize()` method
|
||||
"""
|
||||
|
||||
information = {}
|
||||
|
||||
if text_pair is None:
|
||||
n_added_tokens = self.num_added_tokens()
|
||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||
if add_special_tokens:
|
||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||
if max_length:
|
||||
information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:]
|
||||
sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
|
||||
sequence = self.add_special_tokens_single_sequence(sequence_tokens)
|
||||
information = self.prepare_for_model(sequence_tokens, max_length, stride)
|
||||
else:
|
||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||
if max_length:
|
||||
information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
|
||||
sequence_tokens = sequence_tokens[:max_length]
|
||||
sequence = sequence_tokens
|
||||
information["sequence"] = sequence_tokens
|
||||
|
||||
if output_mask:
|
||||
information["mask"] = [0] * len(sequence)
|
||||
|
||||
information["sequence"] = sequence
|
||||
information["mask"] = [0] * len(information["sequence"])
|
||||
else:
|
||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
||||
f_len, s_len = len(first_sentence_tokens), len(second_sentence_tokens)
|
||||
n_added_tokens = self.num_added_tokens(pair=True)
|
||||
|
||||
if add_special_tokens:
|
||||
if max_length:
|
||||
if len(first_sentence_tokens) + n_added_tokens >= max_length:
|
||||
logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
|
||||
else:
|
||||
if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
|
||||
information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
|
||||
second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
|
||||
|
||||
sequence = self.add_special_tokens_sequence_pair(
|
||||
information = self.prepare_pair_for_model(
|
||||
first_sentence_tokens,
|
||||
second_sentence_tokens
|
||||
second_sentence_tokens,
|
||||
max_length,
|
||||
truncate_second_sequence_first,
|
||||
stride
|
||||
)
|
||||
|
||||
if output_mask:
|
||||
information["mask"] = self.create_mask_from_sequences(text, text_pair)
|
||||
|
||||
information["sequence"] = sequence
|
||||
else:
|
||||
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||
sequence = first_sentence_tokens + second_sentence_tokens
|
||||
@@ -800,6 +794,39 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
return information
|
||||
|
||||
def prepare_for_model(self, ids, max_length=None, stride=0):
|
||||
information = {}
|
||||
n_added_tokens = self.num_added_tokens()
|
||||
if max_length:
|
||||
information["overflowing_tokens"] = ids[max_length - n_added_tokens - stride:]
|
||||
ids = ids[:max_length - n_added_tokens]
|
||||
information["sequence"] = self.add_special_tokens_single_sequence(ids)
|
||||
|
||||
return information
|
||||
|
||||
def prepare_pair_for_model(self, ids_0, ids_1, max_length=None, truncate_second_sequence_first=True, stride=0):
|
||||
f_len, s_len = len(ids_0), len(ids_1)
|
||||
n_added_tokens = self.num_added_tokens(pair=True)
|
||||
information = {}
|
||||
|
||||
if max_length:
|
||||
if len(ids_0) + n_added_tokens >= max_length:
|
||||
logger.warning(
|
||||
"The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
|
||||
else:
|
||||
if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
|
||||
if truncate_second_sequence_first:
|
||||
information["overflowing_tokens"] = ids_1[max_length - f_len - n_added_tokens - stride:]
|
||||
ids_1 = ids_1[:max_length - f_len - n_added_tokens]
|
||||
else:
|
||||
information["overflowing_tokens"] = ids_0[max_length - s_len - n_added_tokens - stride:]
|
||||
ids_0 = ids_0[:max_length - s_len - n_added_tokens]
|
||||
|
||||
sequence = self.add_special_tokens_sequence_pair(ids_0, ids_1)
|
||||
information["sequence"] = sequence
|
||||
|
||||
return information
|
||||
|
||||
def create_mask_from_sequences(self, sequence_0, sequence_1):
|
||||
logger.warning("This tokenizer does not make use of special tokens.")
|
||||
return [0] * len(self.encode(sequence_0)) + [1] * len(self.encode(sequence_1))
|
||||
|
||||
Reference in New Issue
Block a user