From 78ef1a99306213599d4eab8ae48f17a81d1ee2b8 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 4 Oct 2019 17:59:44 -0400 Subject: [PATCH] fixes --- examples/utils_multiple_choice.py | 1 - transformers/data/processors/glue.py | 1 - .../tests/tokenization_tests_commons.py | 4 +- transformers/tokenization_utils.py | 47 +++++++++++-------- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/examples/utils_multiple_choice.py b/examples/utils_multiple_choice.py index a7fc1b1222..a131a63924 100644 --- a/examples/utils_multiple_choice.py +++ b/examples/utils_multiple_choice.py @@ -336,7 +336,6 @@ def convert_examples_to_features( text_b, add_special_tokens=True, max_length=max_length, - truncate_both_sequences=True ) if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0: logger.info('Attention! you are cropping tokens (swag task is ok). ' diff --git a/transformers/data/processors/glue.py b/transformers/data/processors/glue.py index 61bca8c11b..741569ea30 100644 --- a/transformers/data/processors/glue.py +++ b/transformers/data/processors/glue.py @@ -86,7 +86,6 @@ def glue_convert_examples_to_features(examples, tokenizer, example.text_b, add_special_tokens=True, max_length=max_length, - truncate_first_sequence=True # We're truncating the first sequence in priority ) input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py index b8f9295633..b2801d5f41 100644 --- a/transformers/tests/tokenization_tests_commons.py +++ b/transformers/tests/tokenization_tests_commons.py @@ -249,10 +249,10 @@ class CommonTestCases: ) information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True, - stride=stride, truncate_first_sequence=False) + stride=stride, truncation_strategy='only_second') information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True, stride=stride, - truncate_first_sequence=True) + truncation_strategy='only_first') truncated_sequence = information["input_ids"] overflowing_tokens = information["overflowing_tokens"] diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index ce5811b96f..c7b55f3a9c 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -692,8 +692,7 @@ class PreTrainedTokenizer(object): add_special_tokens=False, max_length=None, stride=0, - truncate_first_sequence=True, - truncate_both_sequences=False, + truncation_strategy='longest_first', return_tensors=None, **kwargs): """ @@ -714,8 +713,12 @@ class PreTrainedTokenizer(object): If there are overflowing tokens, those will be added to the returned dictionary stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens from the main sequence returned. The value of this argument defines the number of additional tokens. - truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence - will be truncated. + truncation_strategy: string selected in the following options: + - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length + starting from the longest one at each token (when there is a pair of input sequences) + - 'only_first': Only truncate the first sequence + - 'only_second': Only truncate the second sequence + - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers. **kwargs: passed to the `self.tokenize()` method @@ -725,8 +728,7 @@ class PreTrainedTokenizer(object): max_length=max_length, add_special_tokens=add_special_tokens, stride=stride, - truncate_first_sequence=truncate_first_sequence, - truncate_both_sequences=truncate_both_sequences, + truncation_strategy=truncation_strategy, return_tensors=return_tensors, **kwargs) @@ -738,8 +740,7 @@ class PreTrainedTokenizer(object): add_special_tokens=False, max_length=None, stride=0, - truncate_first_sequence=True, - truncate_both_sequences=False, + truncation_strategy='longest_first', return_tensors=None, **kwargs): """ @@ -759,8 +760,12 @@ class PreTrainedTokenizer(object): If there are overflowing tokens, those will be added to the returned dictionary stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens from the main sequence returned. The value of this argument defines the number of additional tokens. - truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence - will be truncated. + truncation_strategy: string selected in the following options: + - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length + starting from the longest one at each token (when there is a pair of input sequences) + - 'only_first': Only truncate the first sequence + - 'only_second': Only truncate the second sequence + - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers. **kwargs: passed to the `self.tokenize()` method @@ -784,8 +789,7 @@ class PreTrainedTokenizer(object): max_length=max_length, add_special_tokens=add_special_tokens, stride=stride, - truncate_first_sequence=truncate_first_sequence, - truncate_both_sequences=truncate_both_sequences, + truncation_strategy=truncation_strategy, return_tensors=return_tensors) def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0, @@ -812,9 +816,6 @@ class PreTrainedTokenizer(object): - 'only_first': Only truncate the first sequence - 'only_second': Only truncate the second sequence - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length) - truncate_first_sequence: if set to `True` and an optional second list of input ids is provided, - alongside a specified `max_length`, will truncate the first sequence if the total size is superior - than the specified `max_length`. If set to `False`, will truncate the second sequence instead. return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant or PyTorch torch.Tensor instead of a list of python integers. @@ -844,7 +845,8 @@ class PreTrainedTokenizer(object): if max_length and total_len > max_length: ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids, num_tokens_to_remove=total_len-max_length, - truncation_strategy=truncation_strategy) + truncation_strategy=truncation_strategy, + stride=stride) encoded_inputs["overflowing_tokens"] = overflowing_tokens encoded_inputs["num_truncated_tokens"] = total_len - max_length @@ -875,7 +877,7 @@ class PreTrainedTokenizer(object): return encoded_inputs - def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first'): + def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0): """Truncates a sequence pair in place to the maximum length. truncation_strategy: string selected in the following options: - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length @@ -892,17 +894,22 @@ class PreTrainedTokenizer(object): overflowing_tokens = [] for _ in range(num_tokens_to_remove): if pair_ids is None or len(ids) > len(pair_ids): - overflowing_tokens.append(ids[-1]) + overflowing_tokens = [ids[-1]] + overflowing_tokens ids = ids[:-1] else: pair_ids = pair_ids[:-1] + window_len = min(len(ids), stride) + if window_len > 0: + overflowing_tokens = ids[-window_len:] + overflowing_tokens elif truncation_strategy == 'only_first': assert len(ids) > num_tokens_to_remove - overflowing_tokens = ids[-num_tokens_to_remove:] + window_len = min(len(ids), stride + num_tokens_to_remove) + overflowing_tokens = ids[-window_len:] ids = ids[:-num_tokens_to_remove] elif truncation_strategy == 'only_second': assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove - overflowing_tokens = pair_ids[-num_tokens_to_remove:] + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + overflowing_tokens = pair_ids[-window_len:] pair_ids = pair_ids[:-num_tokens_to_remove] elif truncation_strategy == 'do_not_truncate': raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")