Merge remote-tracking branch 'origin/julien_multiple-choice' into encoding-qol

This commit is contained in:
thomwolf
2019-10-04 15:48:06 -04:00
5 changed files with 127 additions and 152 deletions

View File

@@ -693,14 +693,15 @@ class PreTrainedTokenizer(object):
raise NotImplementedError
def encode(self,
text,
text_pair=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
return_tensors=None,
**kwargs):
text,
text_pair=None,
add_special_tokens=False,
max_length=None,
stride=0,
truncate_first_sequence=True,
truncate_both_sequences=False,
return_tensors=None,
**kwargs):
"""
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@@ -718,7 +719,7 @@ class PreTrainedTokenizer(object):
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defined the number of additional tokens.
from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
@@ -731,6 +732,7 @@ class PreTrainedTokenizer(object):
add_special_tokens=add_special_tokens,
stride=stride,
truncate_first_sequence=truncate_first_sequence,
truncate_both_sequences=truncate_both_sequences,
return_tensors=return_tensors,
**kwargs)
@@ -743,6 +745,7 @@ class PreTrainedTokenizer(object):
max_length=None,
stride=0,
truncate_first_sequence=True,
truncate_both_sequences=False,
return_tensors=None,
**kwargs):
"""
@@ -761,7 +764,7 @@ class PreTrainedTokenizer(object):
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defined the number of additional tokens.
from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
will be truncated.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
@@ -788,10 +791,11 @@ class PreTrainedTokenizer(object):
add_special_tokens=add_special_tokens,
stride=stride,
truncate_first_sequence=truncate_first_sequence,
truncate_both_sequences=truncate_both_sequences,
return_tensors=return_tensors)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
truncate_first_sequence=True, return_tensors=None):
truncate_first_sequence=True, truncate_both_sequences=True, return_tensors=None):
"""
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
It adds special tokens, truncates
@@ -838,21 +842,31 @@ class PreTrainedTokenizer(object):
encoded_inputs = {}
if max_length:
n_added_tokens = self.num_added_tokens(pair=pair) if add_special_tokens else 0
if pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
logger.warning(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length. "
"This pair of sequences will be truncated with no regard to the special tokens")
else:
if n_added_tokens + len_ids + len_pair_ids > max_length:
if truncate_first_sequence or not pair:
encoded_inputs["overflowing_tokens"] = ids[max_length - len_pair_ids - n_added_tokens - stride:]
ids = ids[:max_length - len_pair_ids - n_added_tokens]
elif not truncate_first_sequence and pair:
encoded_inputs["overflowing_tokens"] = pair_ids[max_length - len_ids - n_added_tokens - stride:]
pair_ids = pair_ids[:max_length - len_ids - n_added_tokens]
else:
logger.warning(
"Cannot truncate second sequence as it is not provided. No truncation.")
if n_added_tokens + len_ids + len_pair_ids > max_length:
if truncate_both_sequences:
tokens_a, tokens_b = self._truncate_seq_pair(
copy.deepcopy(ids),
copy.deepcopy(pair_ids),
max_length=max_length - n_added_tokens
)
truncated_tokens = ids[- (len_ids - len(tokens_a)):] + pair_ids[- (len_pair_ids - len(tokens_b)):]
encoded_inputs["num_truncated_tokens"] = len(truncated_tokens)
ids = tokens_a
pair_ids = tokens_b
elif pair and n_added_tokens + (len_pair_ids if truncate_first_sequence else len_ids) >= max_length:
logger.warning(
"You supplied a pair of sequence in which the sequence that will not be truncated is longer than the maximum specified length."
"This pair of sequences will not be truncated.")
elif truncate_first_sequence or not pair:
encoded_inputs["overflowing_tokens"] = ids[max_length - len_pair_ids - n_added_tokens - stride:]
ids = ids[:max_length - len_pair_ids - n_added_tokens]
elif not truncate_first_sequence and pair:
encoded_inputs["overflowing_tokens"] = pair_ids[max_length - len_ids - n_added_tokens - stride:]
pair_ids = pair_ids[:max_length - len_ids - n_added_tokens]
else:
logger.warning(
"Cannot truncate second sequence as it is not provided. No truncation.")
if add_special_tokens:
sequence = self.add_special_tokens_sequence_pair(ids, pair_ids) if pair else self.add_special_tokens_single_sequence(ids)
@@ -881,6 +895,25 @@ class PreTrainedTokenizer(object):
return encoded_inputs
def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
# However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
# length or only pop from context
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
return (tokens_a, tokens_b)
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1):
    """Build token type ids for a pair of sequences.

    Logs a warning that this tokenizer does not make use of special tokens,
    then returns a list with one ``0`` per element of ``token_ids_0``
    followed by one ``1`` per element of ``token_ids_1``.
    """
    logger.warning("This tokenizer does not make use of special tokens.")
    first_segment = [0] * len(token_ids_0)
    second_segment = [1] * len(token_ids_1)
    return first_segment + second_segment