This commit is contained in:
thomwolf
2019-10-04 17:59:44 -04:00
parent 6c1d0bc066
commit 78ef1a9930
4 changed files with 29 additions and 24 deletions

View File

@@ -336,7 +336,6 @@ def convert_examples_to_features(
text_b, text_b,
add_special_tokens=True, add_special_tokens=True,
max_length=max_length, max_length=max_length,
truncate_both_sequences=True
) )
if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0: if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0:
logger.info('Attention! you are cropping tokens (swag task is ok). ' logger.info('Attention! you are cropping tokens (swag task is ok). '

View File

@@ -86,7 +86,6 @@ def glue_convert_examples_to_features(examples, tokenizer,
example.text_b, example.text_b,
add_special_tokens=True, add_special_tokens=True,
max_length=max_length, max_length=max_length,
truncate_first_sequence=True # We're truncating the first sequence in priority
) )
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

View File

@@ -249,10 +249,10 @@ class CommonTestCases:
) )
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True, information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
stride=stride, truncate_first_sequence=False) stride=stride, truncation_strategy='only_second')
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
add_special_tokens=True, stride=stride, add_special_tokens=True, stride=stride,
truncate_first_sequence=True) truncation_strategy='only_first')
truncated_sequence = information["input_ids"] truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"] overflowing_tokens = information["overflowing_tokens"]

View File

@@ -692,8 +692,7 @@ class PreTrainedTokenizer(object):
add_special_tokens=False, add_special_tokens=False,
max_length=None, max_length=None,
stride=0, stride=0,
truncate_first_sequence=True, truncation_strategy='longest_first',
truncate_both_sequences=False,
return_tensors=None, return_tensors=None,
**kwargs): **kwargs):
""" """
@@ -714,8 +713,12 @@ class PreTrainedTokenizer(object):
If there are overflowing tokens, those will be added to the returned dictionary If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens. from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence truncation_strategy: string selected in the following options:
will be truncated. - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
starting from the longest one at each token (when there is a pair of input sequences)
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
@@ -725,8 +728,7 @@ class PreTrainedTokenizer(object):
max_length=max_length, max_length=max_length,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
stride=stride, stride=stride,
truncate_first_sequence=truncate_first_sequence, truncation_strategy=truncation_strategy,
truncate_both_sequences=truncate_both_sequences,
return_tensors=return_tensors, return_tensors=return_tensors,
**kwargs) **kwargs)
@@ -738,8 +740,7 @@ class PreTrainedTokenizer(object):
add_special_tokens=False, add_special_tokens=False,
max_length=None, max_length=None,
stride=0, stride=0,
truncate_first_sequence=True, truncation_strategy='longest_first',
truncate_both_sequences=False,
return_tensors=None, return_tensors=None,
**kwargs): **kwargs):
""" """
@@ -759,8 +760,12 @@ class PreTrainedTokenizer(object):
If there are overflowing tokens, those will be added to the returned dictionary If there are overflowing tokens, those will be added to the returned dictionary
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
from the main sequence returned. The value of this argument defines the number of additional tokens. from the main sequence returned. The value of this argument defines the number of additional tokens.
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence truncation_strategy: string selected in the following options:
will be truncated. - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
starting from the longest one at each token (when there is a pair of input sequences)
- 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
**kwargs: passed to the `self.tokenize()` method **kwargs: passed to the `self.tokenize()` method
@@ -784,8 +789,7 @@ class PreTrainedTokenizer(object):
max_length=max_length, max_length=max_length,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
stride=stride, stride=stride,
truncate_first_sequence=truncate_first_sequence, truncation_strategy=truncation_strategy,
truncate_both_sequences=truncate_both_sequences,
return_tensors=return_tensors) return_tensors=return_tensors)
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
@@ -812,9 +816,6 @@ class PreTrainedTokenizer(object):
- 'only_first': Only truncate the first sequence - 'only_first': Only truncate the first sequence
- 'only_second': Only truncate the second sequence - 'only_second': Only truncate the second sequence
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
truncate_first_sequence: if set to `True` and an optional second list of input ids is provided,
alongside a specified `max_length`, will truncate the first sequence if the total size is superior
than the specified `max_length`. If set to `False`, will truncate the second sequence instead.
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
or PyTorch torch.Tensor instead of a list of python integers. or PyTorch torch.Tensor instead of a list of python integers.
@@ -844,7 +845,8 @@ class PreTrainedTokenizer(object):
if max_length and total_len > max_length: if max_length and total_len > max_length:
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids, ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
num_tokens_to_remove=total_len-max_length, num_tokens_to_remove=total_len-max_length,
truncation_strategy=truncation_strategy) truncation_strategy=truncation_strategy,
stride=stride)
encoded_inputs["overflowing_tokens"] = overflowing_tokens encoded_inputs["overflowing_tokens"] = overflowing_tokens
encoded_inputs["num_truncated_tokens"] = total_len - max_length encoded_inputs["num_truncated_tokens"] = total_len - max_length
@@ -875,7 +877,7 @@ class PreTrainedTokenizer(object):
return encoded_inputs return encoded_inputs
def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first'): def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
"""Truncates a sequence pair in place to the maximum length. """Truncates a sequence pair in place to the maximum length.
truncation_strategy: string selected in the following options: truncation_strategy: string selected in the following options:
- 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
@@ -892,17 +894,22 @@ class PreTrainedTokenizer(object):
overflowing_tokens = [] overflowing_tokens = []
for _ in range(num_tokens_to_remove): for _ in range(num_tokens_to_remove):
if pair_ids is None or len(ids) > len(pair_ids): if pair_ids is None or len(ids) > len(pair_ids):
overflowing_tokens.append(ids[-1]) overflowing_tokens = [ids[-1]] + overflowing_tokens
ids = ids[:-1] ids = ids[:-1]
else: else:
pair_ids = pair_ids[:-1] pair_ids = pair_ids[:-1]
window_len = min(len(ids), stride)
if window_len > 0:
overflowing_tokens = ids[-window_len:] + overflowing_tokens
elif truncation_strategy == 'only_first': elif truncation_strategy == 'only_first':
assert len(ids) > num_tokens_to_remove assert len(ids) > num_tokens_to_remove
overflowing_tokens = ids[-num_tokens_to_remove:] window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove] ids = ids[:-num_tokens_to_remove]
elif truncation_strategy == 'only_second': elif truncation_strategy == 'only_second':
assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
overflowing_tokens = pair_ids[-num_tokens_to_remove:] window_len = min(len(pair_ids), stride + num_tokens_to_remove)
overflowing_tokens = pair_ids[-window_len:]
pair_ids = pair_ids[:-num_tokens_to_remove] pair_ids = pair_ids[:-num_tokens_to_remove]
elif truncation_strategy == 'do_not_truncate': elif truncation_strategy == 'do_not_truncate':
raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.") raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")