[multiple-choice] Simplify and use tokenizer.encode_plus
This commit is contained in:
@@ -699,6 +699,7 @@ class PreTrainedTokenizer(object):
|
||||
max_length=None,
|
||||
stride=0,
|
||||
truncate_first_sequence=True,
|
||||
truncate_both_sequences=False,
|
||||
return_tensors=None,
|
||||
**kwargs):
|
||||
"""
|
||||
@@ -718,7 +719,7 @@ class PreTrainedTokenizer(object):
|
||||
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
|
||||
If there are overflowing tokens, those will be added to the returned dictionary
|
||||
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
|
||||
from the main sequence returned. The value of this argument defined the number of additional tokens.
|
||||
from the main sequence returned. The value of this argument defines the number of additional tokens.
|
||||
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
|
||||
will be truncated.
|
||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||
@@ -731,6 +732,7 @@ class PreTrainedTokenizer(object):
|
||||
add_special_tokens=add_special_tokens,
|
||||
stride=stride,
|
||||
truncate_first_sequence=truncate_first_sequence,
|
||||
truncate_both_sequences=truncate_both_sequences,
|
||||
return_tensors=return_tensors,
|
||||
**kwargs)
|
||||
|
||||
@@ -743,6 +745,7 @@ class PreTrainedTokenizer(object):
|
||||
max_length=None,
|
||||
stride=0,
|
||||
truncate_first_sequence=True,
|
||||
truncate_both_sequences=False,
|
||||
return_tensors=None,
|
||||
**kwargs):
|
||||
"""
|
||||
@@ -761,7 +764,7 @@ class PreTrainedTokenizer(object):
|
||||
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
|
||||
If there are overflowing tokens, those will be added to the returned dictionary
|
||||
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
|
||||
from the main sequence returned. The value of this argument defined the number of additional tokens.
|
||||
from the main sequence returned. The value of this argument defines the number of additional tokens.
|
||||
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
|
||||
will be truncated.
|
||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||
@@ -788,11 +791,12 @@ class PreTrainedTokenizer(object):
|
||||
add_special_tokens=add_special_tokens,
|
||||
stride=stride,
|
||||
truncate_first_sequence=truncate_first_sequence,
|
||||
truncate_both_sequences=truncate_both_sequences,
|
||||
return_tensors=return_tensors)
|
||||
|
||||
|
||||
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
|
||||
truncate_first_sequence=True, return_tensors=None):
|
||||
truncate_first_sequence=True, truncate_both_sequences=True, return_tensors=None):
|
||||
"""
|
||||
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
||||
It adds special tokens, truncates
|
||||
@@ -825,21 +829,30 @@ class PreTrainedTokenizer(object):
|
||||
encoded_inputs = {}
|
||||
if max_length:
|
||||
n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
|
||||
if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
|
||||
logger.warning(
|
||||
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
|
||||
"This pair of sequences will not be truncated.")
|
||||
else:
|
||||
if n_added_tokens + len_ids + len_pair_ids > max_length:
|
||||
if truncate_first_sequence or not pair:
|
||||
encoded_inputs["overflowing_tokens"] = ids[max_length - len_pair_ids - n_added_tokens - stride:]
|
||||
ids = ids[:max_length - len_pair_ids - n_added_tokens]
|
||||
elif not truncate_first_sequence and pair:
|
||||
encoded_inputs["overflowing_tokens"] = pair_ids[max_length - len_ids - n_added_tokens - stride:]
|
||||
pair_ids = pair_ids[:max_length - len_ids - n_added_tokens]
|
||||
else:
|
||||
logger.warning(
|
||||
"Cannot truncate second sequence as it is not provided. No truncation.")
|
||||
|
||||
if n_added_tokens + len_ids + len_pair_ids > max_length:
|
||||
if truncate_both_sequences:
|
||||
tokens_a, tokens_b = self._truncate_seq_pair(
|
||||
copy.deepcopy(ids),
|
||||
copy.deepcopy(pair_ids),
|
||||
max_length=max_length - n_added_tokens
|
||||
)
|
||||
encoded_inputs["overflowing_tokens"] = ids[- (len_ids - len(tokens_a)):] + pair_ids[- (len_pair_ids - len(tokens_b)):]
|
||||
ids = tokens_a
|
||||
pair_ids = tokens_b
|
||||
elif pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
|
||||
logger.warning(
|
||||
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
|
||||
"This pair of sequences will not be truncated.")
|
||||
elif truncate_first_sequence or not pair:
|
||||
encoded_inputs["overflowing_tokens"] = ids[max_length - len_pair_ids - n_added_tokens - stride:]
|
||||
ids = ids[:max_length - len_pair_ids - n_added_tokens]
|
||||
elif not truncate_first_sequence and pair:
|
||||
encoded_inputs["overflowing_tokens"] = pair_ids[max_length - len_ids - n_added_tokens - stride:]
|
||||
pair_ids = pair_ids[:max_length - len_ids - n_added_tokens]
|
||||
else:
|
||||
logger.warning(
|
||||
"Cannot truncate second sequence as it is not provided. No truncation.")
|
||||
|
||||
if add_special_tokens:
|
||||
sequence = self.add_special_tokens_sequence_pair(ids, pair_ids) if pair else self.add_special_tokens_single_sequence(ids)
|
||||
@@ -862,6 +875,25 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
|
||||
"""Truncates a sequence pair in place to the maximum length."""
|
||||
|
||||
# This is a simple heuristic which will always truncate the longer sequence
|
||||
# one token at a time. This makes more sense than truncating an equal percent
|
||||
# of tokens from each, since if one sequence is very short then each token
|
||||
# that's truncated likely contains more information than a longer sequence.
|
||||
|
||||
# However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
|
||||
# length or only pop from context
|
||||
while True:
|
||||
total_length = len(tokens_a) + len(tokens_b)
|
||||
if total_length <= max_length:
|
||||
return (tokens_a, tokens_b)
|
||||
if len(tokens_a) > len(tokens_b):
|
||||
tokens_a.pop()
|
||||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
|
||||
logger.warning("This tokenizer does not make use of special tokens.")
|
||||
return [0] * len(token_ids_0) + [1] * len(token_ids_1)
|
||||
|
||||
Reference in New Issue
Block a user