Added the prepare_for_model and prepare_pair_for_model methods. Added an option to select which sequence will be truncated.
This commit is contained in:
@@ -237,16 +237,29 @@ class CommonTestCases:
|
|||||||
|
|
||||||
seq_0 = "This is a sentence to be encoded."
|
seq_0 = "This is a sentence to be encoded."
|
||||||
seq_1 = "This is another sentence to be encoded."
|
seq_1 = "This is another sentence to be encoded."
|
||||||
|
stride = 2
|
||||||
|
|
||||||
|
sequence_0_no_special_tokens = tokenizer.encode(seq_0)
|
||||||
|
sequence_1_no_special_tokens = tokenizer.encode(seq_1)
|
||||||
|
|
||||||
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
|
sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
|
||||||
truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
|
truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
|
||||||
tokenizer.encode(seq_0),
|
tokenizer.encode(seq_0),
|
||||||
tokenizer.encode(seq_1)[:-2]
|
tokenizer.encode(seq_1)[:-2]
|
||||||
)
|
)
|
||||||
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True)
|
|
||||||
|
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
|
||||||
|
stride=stride)
|
||||||
|
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
|
||||||
|
add_special_tokens=True, stride=stride,
|
||||||
|
truncate_second_sequence_first=False)
|
||||||
|
|
||||||
truncated_sequence = information["sequence"]
|
truncated_sequence = information["sequence"]
|
||||||
overflowing_tokens = information["overflowing_tokens"]
|
overflowing_tokens = information["overflowing_tokens"]
|
||||||
|
overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]
|
||||||
|
|
||||||
|
assert len(overflowing_tokens) == 2 + stride
|
||||||
|
assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):]
|
||||||
|
assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):]
|
||||||
assert len(truncated_sequence) == len(sequence) - 2
|
assert len(truncated_sequence) == len(sequence) - 2
|
||||||
assert truncated_sequence == truncated_second_sequence
|
assert truncated_sequence == truncated_second_sequence
|
||||||
|
|||||||
@@ -722,7 +722,15 @@ class PreTrainedTokenizer(object):
|
|||||||
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||||
return first_sentence_tokens + second_sentence_tokens
|
return first_sentence_tokens + second_sentence_tokens
|
||||||
|
|
||||||
def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, stride=0, **kwargs):
|
def encode_plus(self,
|
||||||
|
text,
|
||||||
|
text_pair=None,
|
||||||
|
add_special_tokens=False,
|
||||||
|
output_mask=False,
|
||||||
|
max_length=None,
|
||||||
|
stride=0,
|
||||||
|
truncate_second_sequence_first=True,
|
||||||
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
|
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
|
||||||
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
|
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
|
||||||
@@ -738,54 +746,40 @@ class PreTrainedTokenizer(object):
|
|||||||
If there are overflowing tokens, those will be added to the returned dictionary
|
If there are overflowing tokens, those will be added to the returned dictionary
|
||||||
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
|
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
|
||||||
from the main sequence returned. The value of this argument defined the number of additional tokens.
|
from the main sequence returned. The value of this argument defined the number of additional tokens.
|
||||||
|
truncate_second_sequence_first: if there is a specified max_length, this flag will choose which sequence
|
||||||
|
will be truncated.
|
||||||
**kwargs: passed to the `self.tokenize()` method
|
**kwargs: passed to the `self.tokenize()` method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
information = {}
|
information = {}
|
||||||
|
|
||||||
if text_pair is None:
|
if text_pair is None:
|
||||||
n_added_tokens = self.num_added_tokens()
|
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
information = self.prepare_for_model(sequence_tokens, max_length, stride)
|
||||||
if max_length:
|
|
||||||
information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens - stride:]
|
|
||||||
sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
|
|
||||||
sequence = self.add_special_tokens_single_sequence(sequence_tokens)
|
|
||||||
else:
|
else:
|
||||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
|
||||||
if max_length:
|
if max_length:
|
||||||
information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
|
information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
|
||||||
sequence_tokens = sequence_tokens[:max_length]
|
sequence_tokens = sequence_tokens[:max_length]
|
||||||
sequence = sequence_tokens
|
information["sequence"] = sequence_tokens
|
||||||
|
|
||||||
if output_mask:
|
if output_mask:
|
||||||
information["mask"] = [0] * len(sequence)
|
information["mask"] = [0] * len(information["sequence"])
|
||||||
|
|
||||||
information["sequence"] = sequence
|
|
||||||
else:
|
else:
|
||||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
||||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
||||||
f_len, s_len = len(first_sentence_tokens), len(second_sentence_tokens)
|
|
||||||
n_added_tokens = self.num_added_tokens(pair=True)
|
|
||||||
|
|
||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
if max_length:
|
information = self.prepare_pair_for_model(
|
||||||
if len(first_sentence_tokens) + n_added_tokens >= max_length:
|
|
||||||
logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
|
|
||||||
else:
|
|
||||||
if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
|
|
||||||
information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
|
|
||||||
second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
|
|
||||||
|
|
||||||
sequence = self.add_special_tokens_sequence_pair(
|
|
||||||
first_sentence_tokens,
|
first_sentence_tokens,
|
||||||
second_sentence_tokens
|
second_sentence_tokens,
|
||||||
|
max_length,
|
||||||
|
truncate_second_sequence_first,
|
||||||
|
stride
|
||||||
)
|
)
|
||||||
|
|
||||||
if output_mask:
|
if output_mask:
|
||||||
information["mask"] = self.create_mask_from_sequences(text, text_pair)
|
information["mask"] = self.create_mask_from_sequences(text, text_pair)
|
||||||
|
|
||||||
information["sequence"] = sequence
|
|
||||||
else:
|
else:
|
||||||
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||||
sequence = first_sentence_tokens + second_sentence_tokens
|
sequence = first_sentence_tokens + second_sentence_tokens
|
||||||
@@ -800,6 +794,39 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
return information
|
return information
|
||||||
|
|
||||||
|
def prepare_for_model(self, ids, max_length=None, stride=0):
    """
    Wraps a single sequence of token ids with ``self.add_special_tokens_single_sequence`` and, if
    ``max_length`` is given, truncates the ids first so that the ids plus the count reported by
    ``self.num_added_tokens()`` fit within ``max_length``.

    Args:
        ids: list of token ids, without special tokens.
        max_length: if set to a truthy number, upper bound on the length of the returned sequence.
        stride: number of tokens from the end of the kept ids that are repeated at the start of
            the overflowing tokens.

    Returns:
        A dictionary with:
            - ``"sequence"``: the (possibly truncated) ids passed through
              ``self.add_special_tokens_single_sequence``.
            - ``"overflowing_tokens"``: only present when ``max_length`` is set. The tokens that were
              cut off, preceded by up to ``stride`` kept tokens; an empty list when nothing was cut.
    """
    information = {}
    n_added_tokens = self.num_added_tokens()

    if max_length:
        # Room left for real tokens. Clamped so a tiny max_length cannot turn into a
        # negative slice bound (which would slice from the end of the list instead).
        budget = max(max_length - n_added_tokens, 0)
        if len(ids) > budget:
            # Clamp the window start as well: a stride larger than the budget must not
            # wrap around to a negative index and silently drop tokens.
            information["overflowing_tokens"] = ids[max(budget - stride, 0):]
            ids = ids[:budget]
        else:
            # Nothing was truncated, so nothing overflows (the key is kept for callers
            # that read it whenever max_length is set).
            information["overflowing_tokens"] = []

    information["sequence"] = self.add_special_tokens_single_sequence(ids)

    return information
|
||||||
|
|
||||||
|
def prepare_pair_for_model(self, ids_0, ids_1, max_length=None, truncate_second_sequence_first=True, stride=0):
    """
    Joins two sequences of token ids with ``self.add_special_tokens_sequence_pair`` and, if
    ``max_length`` is given and the pair would exceed it, truncates one of the two sequences first.

    Args:
        ids_0: list of token ids of the first sequence, without special tokens.
        ids_1: list of token ids of the second sequence, without special tokens.
        max_length: if set to a truthy number, upper bound on the length of the returned sequence.
        truncate_second_sequence_first: if True, ``ids_1`` is the sequence that gets truncated and
            ``ids_0`` is kept whole; if False, ``ids_0`` is truncated and ``ids_1`` is kept whole.
        stride: number of tokens from the end of the kept part of the truncated sequence that are
            repeated at the start of the overflowing tokens.

    Returns:
        A dictionary with:
            - ``"sequence"``: the pair passed through ``self.add_special_tokens_sequence_pair``.
            - ``"overflowing_tokens"``: only present when a truncation happened. The tokens cut from
              the truncated sequence, preceded by up to ``stride`` kept tokens.

    If the sequence that must be kept whole already fills ``max_length`` on its own (together with
    the count reported by ``self.num_added_tokens(pair=True)``), a warning is logged and nothing is
    truncated.
    """
    information = {}
    n_added_tokens = self.num_added_tokens(pair=True)

    if max_length and len(ids_0) + len(ids_1) + n_added_tokens > max_length:
        # The guard has to look at the sequence that is kept whole, which depends on the flag:
        # checking ids_0 unconditionally yields a non-positive slice bound for ids_0 when the
        # second sequence is the long one and the first is the one being truncated.
        kept_len = len(ids_0) if truncate_second_sequence_first else len(ids_1)
        budget = max_length - kept_len - n_added_tokens

        if budget <= 0:
            if truncate_second_sequence_first:
                logger.warning(
                    "The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
            else:
                logger.warning(
                    "The second sequence is longer than the maximum specified length. This sequence will not be truncated.")
        else:
            # Clamp so that a stride larger than the budget does not wrap to a negative index.
            overflow_start = max(budget - stride, 0)
            if truncate_second_sequence_first:
                information["overflowing_tokens"] = ids_1[overflow_start:]
                ids_1 = ids_1[:budget]
            else:
                information["overflowing_tokens"] = ids_0[overflow_start:]
                ids_0 = ids_0[:budget]

    information["sequence"] = self.add_special_tokens_sequence_pair(ids_0, ids_1)

    return information
|
||||||
|
|
||||||
def create_mask_from_sequences(self, sequence_0, sequence_1):
    """
    Builds a mask of 0s followed by 1s for a pair of sequences: one 0 per element of
    ``self.encode(sequence_0)``, then one 1 per element of ``self.encode(sequence_1)``.
    A warning is logged on every call.
    """
    logger.warning("This tokenizer does not make use of special tokens.")
    first_segment_length = len(self.encode(sequence_0))
    second_segment_length = len(self.encode(sequence_1))
    return [0] * first_segment_length + [1] * second_segment_length
|
||||||
|
|||||||
Reference in New Issue
Block a user