From 17ade127b9d540d6029cace38eec61a267ea604d Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Fri, 3 Jul 2020 10:51:21 -0400 Subject: [PATCH] Exposing prepare_for_model for both slow & fast tokenizers (#5479) * Exposing prepare_for_model for both slow & fast tokenizers * Update method signature * The traditional style commit * Hide the warnings behind the verbose flag * update default truncation strategy and prepare_for_model * fix tests and prepare_for_models methods Co-authored-by: Thomas Wolf --- src/transformers/tokenization_utils.py | 195 +--------------- src/transformers/tokenization_utils_base.py | 238 +++++++++++++++++++- tests/test_tokenization_common.py | 50 +++- tests/test_tokenization_fast.py | 7 + 4 files changed, 285 insertions(+), 205 deletions(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 15fb58bff0..16dbec4db9 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -454,12 +454,12 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): first_ids = get_input_ids(text) second_ids = get_input_ids(text_pair) if text_pair is not None else None - return self._prepare_for_model( + return self.prepare_for_model( first_ids, pair_ids=second_ids, add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, + padding=padding_strategy.value, + truncation=truncation_strategy.value, max_length=max_length, stride=stride, pad_to_multiple_of=pad_to_multiple_of, @@ -584,7 +584,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): batch_outputs = {} for first_ids, second_ids in batch_ids_pairs: - outputs = self._prepare_for_model( + outputs = self.prepare_for_model( first_ids, second_ids, add_special_tokens=add_special_tokens, @@ -620,109 +620,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): return batch_outputs - @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) - def _prepare_for_model( - self, - ids: List[int], - pair_ids: Optional[List[int]] = None, - add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_tensors: Optional[str] = None, - prepend_batch_axis: bool = False, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_length: bool = False, - verbose: bool = True, - ) -> BatchEncoding: - """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. - It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and - manages a moving window (with user defined stride) for overflowing tokens - - Args: - ids: list of tokenized input ids. Can be obtained from a string by chaining the - `tokenize` and `convert_tokens_to_ids` methods. - pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the - `tokenize` and `convert_tokens_to_ids` methods. - """ - pair = bool(pair_ids is not None) - len_ids = len(ids) - len_pair_ids = len(pair_ids) if pair else 0 - - # Load from model defaults - if return_token_type_ids is None: - return_token_type_ids = "token_type_ids" in self.model_input_names - if return_attention_mask is None: - return_attention_mask = "attention_mask" in self.model_input_names - - encoded_inputs = {} - - # Compute the total size of the returned encodings - total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) - - # Truncation: Handle max sequence length - if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: - ids, pair_ids, overflowing_tokens = self.truncate_sequences( - ids, - pair_ids=pair_ids, - num_tokens_to_remove=total_len - max_length, - truncation_strategy=truncation_strategy, - stride=stride, - ) - if return_overflowing_tokens: - encoded_inputs["overflowing_tokens"] = overflowing_tokens - encoded_inputs["num_truncated_tokens"] = total_len - max_length - - # Add special tokens - if add_special_tokens: - sequence = self.build_inputs_with_special_tokens(ids, pair_ids) - token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) - else: - sequence = ids + pair_ids if pair else ids - token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else []) - - # Build output dictionnary - encoded_inputs["input_ids"] = sequence - if return_token_type_ids: - encoded_inputs["token_type_ids"] = token_type_ids - if return_special_tokens_mask: - if add_special_tokens: - encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) - else: - encoded_inputs["special_tokens_mask"] = [0] * len(sequence) - - # Check lengths - if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose: - logger.warning( - "Token indices sequence length is longer than the specified maximum sequence length " - "for this model ({} > {}). Running this sequence through the model will result in " - "indexing errors".format(len(ids), self.model_max_length) - ) - - # Padding - if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: - encoded_inputs = self.pad( - encoded_inputs, - max_length=max_length, - padding=padding_strategy.value, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - ) - - if return_length: - encoded_inputs["length"] = len(encoded_inputs["input_ids"]) - - batch_outputs = BatchEncoding( - encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis - ) - - return batch_outputs - def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict): """ Performs any necessary transformations before tokenization. @@ -731,90 +628,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): """ return (text, kwargs) - def truncate_sequences( - self, - ids: List[int], - pair_ids: Optional[List[int]] = None, - num_tokens_to_remove: int = 0, - truncation_strategy: Union[str, TruncationStrategy] = "only_first", - stride: int = 0, - ) -> Tuple[List[int], List[int], List[int]]: - """ Truncates a sequence pair in place to the maximum length. - - Args: - ids: list of tokenized input ids. Can be obtained from a string by chaining the - `tokenize` and `convert_tokens_to_ids` methods. - pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the - `tokenize` and `convert_tokens_to_ids` methods. - num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``): - number of tokens to remove using the truncation strategy - truncation_strategy (:obj:`string`, `optional`, defaults to "only_first"): - String selected in the following options: - - - 'only_first' (default): Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove. - - 'only_second': Only truncate the second sequence - - 'longest_first': Iteratively reduce the inputs sequence until the input is under max_length - starting from the longest one at each token (when there is a pair of input sequences). - Overflowing tokens only contains overflow from the first sequence. - - 'do_not_truncate' - stride (:obj:`int`, `optional`, defaults to ``0``): - If set to a number along with max_length, the overflowing tokens returned will contain some tokens - from the main sequence returned. The value of this argument defines the number of additional tokens. - """ - if num_tokens_to_remove <= 0: - return ids, pair_ids, [] - - if not isinstance(truncation_strategy, TruncationStrategy): - truncation_strategy = TruncationStrategy(truncation_strategy) - - overflowing_tokens = [] - if truncation_strategy == TruncationStrategy.LONGEST_FIRST: - for _ in range(num_tokens_to_remove): - if pair_ids is None or len(ids) > len(pair_ids): - ids = ids[:-1] - else: - pair_ids = pair_ids[:-1] - elif truncation_strategy == TruncationStrategy.ONLY_FIRST: - if len(ids) > num_tokens_to_remove: - window_len = min(len(ids), stride + num_tokens_to_remove) - overflowing_tokens = ids[-window_len:] - ids = ids[:-num_tokens_to_remove] - else: - logger.error( - f"We need to remove {num_tokens_to_remove} to truncate the input" - f"but the first sequence has a length {len(ids)}. " - f"Please select another truncation strategy than {truncation_strategy}, " - f"for instance 'longest_first' or 'only_second'." - ) - elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: - if len(pair_ids) > num_tokens_to_remove: - window_len = min(len(pair_ids), stride + num_tokens_to_remove) - overflowing_tokens = pair_ids[-window_len:] - pair_ids = pair_ids[:-num_tokens_to_remove] - else: - logger.error( - f"We need to remove {num_tokens_to_remove} to truncate the input" - f"but the second sequence has a length {len(pair_ids)}. " - f"Please select another truncation strategy than {truncation_strategy}, " - f"for instance 'longest_first' or 'only_first'." - ) - - return (ids, pair_ids, overflowing_tokens) - - def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]: - if token_ids_1 is None: - return len(token_ids_0) * [0] - return [0] * len(token_ids_0) + [1] * len(token_ids_1) - - def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks - by concatenating and adding special tokens. This implementation does not add special tokens. - """ - if token_ids_1 is None: - return token_ids_0 - return token_ids_0 + token_ids_1 - def get_special_tokens_mask( self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False ) -> List[int]: diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 073d84a948..6a1219a4d3 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -945,9 +945,9 @@ ENCODE_KWARGS_DOCSTRING = r""" `truncation` (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`): Activate and control truncation. Accepts the following values: - * `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided, + * `True` or `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided, + * `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided, * `'only_second'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided, - * `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided, * `False` or `'do_not_truncate'` (default): No truncation (i.e. can output batch with sequences length greater than the model max admissible input size) `max_length` (:obj:`Union[int, None]`, `optional`, defaults to :obj:`None`): Control the length for padding/truncation. Accepts the following values @@ -1446,10 +1446,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): logger.warning( "Truncation was not explicitely activated but `max_length` is provided a specific value, " "please use `truncation=True` to explicitely truncate examples to max length. " - "Defaulting to 'only_first' truncation strategy. " - "If you encode pairs of sequences (GLUE-style) with the tokenizer you may want to check this is the right behavior." + "Defaulting to 'longest_first' truncation strategy. " + "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy " + "more precisely by providing a specific strategy to `truncation`." ) - truncation = "only_first" + truncation = "longest_first" # Get padding strategy if padding is False and old_pad_to_max_length: @@ -1469,7 +1470,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): elif padding is not False: if padding is True: padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch - else: + elif not isinstance(padding, PaddingStrategy): padding_strategy = PaddingStrategy(padding) else: padding_strategy = PaddingStrategy.DO_NOT_PAD @@ -1492,9 +1493,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): elif truncation is not False: if truncation is True: truncation_strategy = ( - TruncationStrategy.ONLY_FIRST - ) # Default to truncate the first sequences in pairs of inputs - else: + TruncationStrategy.LONGEST_FIRST + ) # Default to truncate the longest sequences in pairs of inputs + elif not isinstance(truncation, TruncationStrategy): truncation_strategy = TruncationStrategy(truncation) else: truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE @@ -1960,6 +1961,225 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): return BatchEncoding(batch_outputs, tensor_type=return_tensors) + def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]: + if token_ids_1 is None: + return len(token_ids_0) * [0] + return [0] * len(token_ids_0) + [1] * len(token_ids_1) + + def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. This implementation does not add special tokens. + """ + if token_ids_1 is None: + return token_ids_0 + return token_ids_0 + token_ids_1 + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def prepare_for_model( + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs + ) -> BatchEncoding: + """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. + It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + Args: + ids: list of tokenized input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + """ + + if "return_lengths" in kwargs: + if verbose: + warnings.warn( + "The PreTrainedTokenizerBase.prepare_for_model `return_lengths` parameter is deprecated. " + "Please use `return_length` instead.", + FutureWarning, + ) + return_length = kwargs["return_lengths"] + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + # Compute the total size of the returned encodings + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ids, pair_ids, overflowing_tokens = self.truncate_sequences( + ids, + pair_ids=pair_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else []) + + # Build output dictionnary + encoded_inputs["input_ids"] = sequence + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + # Check lengths + if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose: + logger.warning( + "Token indices sequence length is longer than the specified maximum sequence length " + "for this model ({} > {}). Running this sequence through the model will result in " + "indexing errors".format(len(ids), self.model_max_length) + ) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + def truncate_sequences( + self, + ids: List[int], + pair_ids: Optional[List[int]] = None, + num_tokens_to_remove: int = 0, + truncation_strategy: Union[str, TruncationStrategy] = "longest_first", + stride: int = 0, + ) -> Tuple[List[int], List[int], List[int]]: + """ Truncates a sequence pair in place to the maximum length. + + Args: + ids: list of tokenized input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the + `tokenize` and `convert_tokens_to_ids` methods. + num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``): + number of tokens to remove using the truncation strategy + truncation_strategy (:obj:`string`, `optional`, defaults to "longest_first"): + String selected in the following options: + + - 'longest_first' (default): Iteratively reduce the inputs sequence until the input is under max_length + starting from the longest one at each token (when there is a pair of input sequences). + Overflowing tokens only contains overflow from the first sequence. + - 'only_first': Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove. + - 'only_second': Only truncate the second sequence + - 'do_not_truncate' + stride (:obj:`int`, `optional`, defaults to ``0``): + If set to a number along with max_length, the overflowing tokens returned will contain some tokens + from the main sequence returned. The value of this argument defines the number of additional tokens. + """ + if num_tokens_to_remove <= 0: + return ids, pair_ids, [] + + if not isinstance(truncation_strategy, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation_strategy) + + overflowing_tokens = [] + if truncation_strategy == TruncationStrategy.LONGEST_FIRST: + for _ in range(num_tokens_to_remove): + if pair_ids is None or len(ids) > len(pair_ids): + if not overflowing_tokens: + window_len = min(len(ids), stride + 1) + else: + window_len = 1 + overflowing_tokens.extend(ids[-window_len:]) + ids = ids[:-1] + else: + if not overflowing_tokens: + window_len = min(len(pair_ids), stride + 1) + else: + window_len = 1 + overflowing_tokens.extend(pair_ids[-window_len:]) + pair_ids = pair_ids[:-1] + elif truncation_strategy == TruncationStrategy.ONLY_FIRST: + if len(ids) > num_tokens_to_remove: + window_len = min(len(ids), stride + num_tokens_to_remove) + overflowing_tokens = ids[-window_len:] + ids = ids[:-num_tokens_to_remove] + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input" + f"but the first sequence has a length {len(ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + f"for instance 'longest_first' or 'only_second'." + ) + elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: + if len(pair_ids) > num_tokens_to_remove: + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + overflowing_tokens = pair_ids[-window_len:] + pair_ids = pair_ids[:-num_tokens_to_remove] + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input" + f"but the second sequence has a length {len(pair_ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + f"for instance 'longest_first' or 'only_first'." + ) + + return (ids, pair_ids, overflowing_tokens) + def _pad( self, encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index d14fec512b..d23de55efa 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -508,9 +508,7 @@ class TokenizerTesterMixin: self.assertEqual(len(truncated_sequence), total_length - 2) self.assertEqual(truncated_sequence, sequence[:-2]) - self.assertEqual( - len(overflowing_tokens), 0 - ) # No overflowing tokens when using 'longest' in python tokenizers + self.assertEqual(len(overflowing_tokens), 2 + stride) def test_maximum_encoding_length_pair_input(self): tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100) @@ -634,7 +632,39 @@ class TokenizerTesterMixin: self.assertEqual(truncated_sequence, truncated_longest_sequence) self.assertEqual( - len(overflowing_tokens), 0 + len(overflowing_tokens), 2 + stride + ) # No overflowing tokens when using 'longest' in python tokenizers + + information = tokenizer.encode_plus( + seq_0, + seq_1, + max_length=len(sequence) - 2, + add_special_tokens=False, + stride=stride, + truncation=True, + return_overflowing_tokens=True, + # add_prefix_space=False, + ) + # Overflowing tokens are handled quite differently in slow and fast tokenizers + if isinstance(tokenizer, PreTrainedTokenizerFast): + truncated_sequence = information["input_ids"][0] + overflowing_tokens = information["input_ids"][1] + self.assertEqual(len(information["input_ids"]), 2) + + self.assertEqual(len(truncated_sequence), len(sequence) - 2) + self.assertEqual(truncated_sequence, truncated_longest_sequence) + + self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest)) + self.assertEqual(overflowing_tokens, overflow_longest_sequence) + else: + truncated_sequence = information["input_ids"] + overflowing_tokens = information["overflowing_tokens"] + + self.assertEqual(len(truncated_sequence), len(sequence) - 2) + self.assertEqual(truncated_sequence, truncated_longest_sequence) + + self.assertEqual( + len(overflowing_tokens), 2 + stride ) # No overflowing tokens when using 'longest' in python tokenizers information_first_truncated = tokenizer.encode_plus( @@ -643,7 +673,7 @@ class TokenizerTesterMixin: max_length=len(sequence) - 2, add_special_tokens=False, stride=stride, - truncation=True, + truncation="only_first", return_overflowing_tokens=True, # add_prefix_space=False, ) @@ -1293,6 +1323,16 @@ class TokenizerTesterMixin: for key in output.keys(): self.assertEqual(output[key], output_sequence[key]) + def test_prepare_for_model(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + string_sequence = "Testing the prepare_for_model method." + ids = tokenizer.encode(string_sequence, add_special_tokens=False) + input_dict = tokenizer.encode_plus(string_sequence) + prepared_input_dict = tokenizer.prepare_for_model(ids) + + self.assertEqual(input_dict, prepared_input_dict) + @require_torch @require_tf def test_batch_encode_plus_tensors(self): diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index e2de9acf03..0682df6729 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -90,6 +90,7 @@ class CommonFastTokenizerTest(unittest.TestCase): self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p) self.assert_padding(tokenizer_r, tokenizer_p) self.assert_create_token_type_ids(tokenizer_r, tokenizer_p) + self.assert_prepare_for_model(tokenizer_r, tokenizer_p) # TODO: enable for v3.0.0 # self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p) @@ -709,6 +710,12 @@ class CommonFastTokenizerTest(unittest.TestCase): for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]): self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add) + def assert_prepare_for_model(self, tokenizer_r, tokenizer_p): + string_sequence = "Asserting that both tokenizers are equal" + python_output = tokenizer_p.prepare_for_model(tokenizer_p.encode(string_sequence)) + rust_output = tokenizer_r.prepare_for_model(tokenizer_r.encode(string_sequence)) + self.assertEqual(python_output, rust_output) + class WordPieceFastTokenizerTest(CommonFastTokenizerTest): """