update encode_plus - add truncation strategies
This commit is contained in:
@@ -538,15 +538,9 @@ class PreTrainedTokenizer(object):
|
||||
Returns:
|
||||
Number of tokens added to sequences
|
||||
"""
|
||||
|
||||
if pair:
|
||||
initial_tokens_len = len(self.encode("This is a sequence") + self.encode("This is another"))
|
||||
final_tokens_len = len(self.encode("This is a sequence", "This is another", add_special_tokens=True))
|
||||
else:
|
||||
initial_tokens_len = len(self.encode("This is a sequence"))
|
||||
final_tokens_len = len(self.encode("This is a sequence", add_special_tokens=True))
|
||||
|
||||
return final_tokens_len - initial_tokens_len
|
||||
token_ids_0 = []
|
||||
token_ids_1 = []
|
||||
return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict):
|
||||
"""
|
||||
@@ -795,7 +789,7 @@ class PreTrainedTokenizer(object):
|
||||
return_tensors=return_tensors)
|
||||
|
||||
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
|
||||
truncate_first_sequence=True, truncate_both_sequences=True, return_tensors=None):
|
||||
truncation_strategy='longest_first', return_tensors=None):
|
||||
"""
|
||||
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
||||
It adds special tokens, truncates
|
||||
@@ -812,6 +806,12 @@ class PreTrainedTokenizer(object):
|
||||
to their model.
|
||||
stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential
|
||||
list of inputs.
|
||||
truncation_strategy: string selected in the following options:
|
||||
- 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
|
||||
starting from the longest one at each token (when there is a pair of input sequences)
|
||||
- 'only_first': Only truncate the first sequence
|
||||
- 'only_second': Only truncate the second sequence
|
||||
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
|
||||
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
|
||||
alongside a specified `max_length`, will truncate the first sequence if the total size is superior
|
||||
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
|
||||
@@ -840,37 +840,17 @@ class PreTrainedTokenizer(object):
|
||||
len_pair_ids = len(pair_ids) if pair else 0
|
||||
|
||||
encoded_inputs = {}
|
||||
if max_length:
|
||||
n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
|
||||
|
||||
if n_added_tokens + len_ids + len_pair_ids > max_length:
|
||||
if truncate_both_sequences:
|
||||
tokens_a, tokens_b = self._truncate_seq_pair(
|
||||
copy.deepcopy(ids),
|
||||
copy.deepcopy(pair_ids),
|
||||
max_length=max_length - n_added_tokens
|
||||
)
|
||||
truncated_tokens = ids[- (len_ids - len(tokens_a)):] + pair_ids[- (len_pair_ids - len(tokens_b)):]
|
||||
encoded_inputs["num_truncated_tokens"] = len(truncated_tokens)
|
||||
ids = tokens_a
|
||||
pair_ids = tokens_b
|
||||
elif pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
|
||||
logger.warning(
|
||||
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
|
||||
"This pair of sequences will not be truncated.")
|
||||
elif truncate_first_sequence or not pair:
|
||||
encoded_inputs["overflowing_tokens"] = ids[max_length - len_pair_ids - n_added_tokens - stride:]
|
||||
ids = ids[:max_length - len_pair_ids - n_added_tokens]
|
||||
elif not truncate_first_sequence and pair:
|
||||
encoded_inputs["overflowing_tokens"] = pair_ids[max_length - len_ids - n_added_tokens - stride:]
|
||||
pair_ids = pair_ids[:max_length - len_ids - n_added_tokens]
|
||||
else:
|
||||
logger.warning(
|
||||
"Cannot truncate second sequence as it is not provided. No truncation.")
|
||||
total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
|
||||
if max_length and total_len > max_length:
|
||||
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
|
||||
num_tokens_to_remove=total_len-max_length,
|
||||
truncation_strategy=truncation_strategy)
|
||||
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
||||
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
||||
|
||||
if add_special_tokens:
|
||||
sequence = self.add_special_tokens_sequence_pair(ids, pair_ids) if pair else self.add_special_tokens_single_sequence(ids)
|
||||
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) if pair else [0] * len(sequence)
|
||||
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
||||
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
||||
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
||||
else:
|
||||
sequence = ids + pair_ids if pair else ids
|
||||
@@ -895,39 +875,76 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first'):
|
||||
"""Truncates a sequence pair in place to the maximum length.
|
||||
truncation_strategy: string selected in the following options:
|
||||
- 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
|
||||
starting from the longest one at each token (when there is a pair of input sequences).
|
||||
Overflowing tokens only contains overflow from the first sequence.
|
||||
- 'only_first': Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove.
|
||||
- 'only_second': Only truncate the second sequence
|
||||
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
|
||||
"""
|
||||
if num_tokens_to_remove <= 0:
|
||||
return ids, pair_ids, []
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
if truncation_strategy == 'longest_first':
|
||||
overflowing_tokens = []
|
||||
for _ in range(num_tokens_to_remove):
|
||||
if pair_ids is None or len(ids) > len(pair_ids):
|
||||
overflowing_tokens.append(ids[-1])
|
||||
ids = ids[:-1]
|
||||
else:
|
||||
pair_ids = pair_ids[:-1]
|
||||
elif truncation_strategy == 'only_first':
|
||||
assert len(ids) > num_tokens_to_remove
|
||||
overflowing_tokens = ids[-num_tokens_to_remove:]
|
||||
ids = ids[:-num_tokens_to_remove]
|
||||
elif truncation_strategy == 'only_second':
|
||||
assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
|
||||
overflowing_tokens = pair_ids[-num_tokens_to_remove:]
|
||||
pair_ids = pair_ids[:-num_tokens_to_remove]
|
||||
elif truncation_strategy == 'do_not_truncate':
|
||||
raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")
|
||||
else:
|
||||
raise ValueError("Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']")
|
||||
return (ids, pair_ids, overflowing_tokens)
|
||||
|
||||
# However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
|
||||
# length or only pop from context
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
return (tokens_a, tokens_b)
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
|
||||
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
|
||||
logger.warning("This tokenizer does not make use of special tokens.")
|
||||
if token_ids_1 is None:
|
||||
return len(token_ids_0) * [0]
|
||||
return [0] * len(token_ids_0) + [1] * len(token_ids_1)
|
||||
|
||||
def add_special_tokens_single_sequence(self, token_ids):
|
||||
logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
|
||||
return token_ids
|
||||
|
||||
def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
|
||||
logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
||||
by concatenating and adding special tokens.
|
||||
A RoBERTa sequence has the following format:
|
||||
single sequence: <s> X </s>
|
||||
pair of sequences: <s> A </s></s> B </s>
|
||||
"""
|
||||
logger.warning("This tokenizer does not make use of special tokens. Input is returned with no modification.")
|
||||
if token_ids_1 is None:
|
||||
return token_ids_0
|
||||
return token_ids_0 + token_ids_1
|
||||
|
||||
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
|
||||
return [1] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
|
||||
def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
|
||||
"""
|
||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||
|
||||
Args:
|
||||
token_ids_0: list of ids (must not contain special tokens)
|
||||
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
|
||||
for sequence pairs
|
||||
already_has_special_tokens: (default False) Set to True if the token list is already formated with
|
||||
special tokens for the model
|
||||
|
||||
Returns:
|
||||
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
|
||||
"""
|
||||
return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
""" Converts a single index or a sequence of indices (integers) in a token "
|
||||
|
||||
Reference in New Issue
Block a user