From ebc36108dc1c20985905c79f7d6a00f57f3cd3ae Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Mon, 22 Jun 2020 17:25:43 +0200 Subject: [PATCH] [tokenizers] Fix #5081 and improve backward compatibility (#5125) * fix #5081 and improve backward compatibility (slightly) * add nlp to setup.cfg - style and quality * align default to previous default * remove test that doesn't generalize --- src/transformers/tokenization_utils.py | 10 ----- src/transformers/tokenization_utils_base.py | 45 +++++++++++++++++++++ tests/test_tokenization_common.py | 23 ----------- 3 files changed, 45 insertions(+), 33 deletions(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 932ccef0f7..8b17e6325b 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -755,16 +755,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): def decode( self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True ) -> str: - """ - Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary - with options to remove special tokens and clean up tokenization spaces. - Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. - - Args: - token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods. - skip_special_tokens: if set to True, will replace special tokens. - clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces. 
- """ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) # To avoid mixing byte-level and unicode for byte-level BPT diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 28e8354c88..0f14aad8e4 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1774,6 +1774,51 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]: return [self.decode(seq, **kwargs) for seq in sequences] + def decode( + self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True + ) -> str: + """ + Converts a sequence of ids (integer) into a string, using the tokenizer and vocabulary + with options to remove special tokens and clean up tokenization spaces. + Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``. + + Args: + token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods. + skip_special_tokens: if set to True, will remove special tokens. + clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces. + """ + raise NotImplementedError + + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. 
+ + Args: + token_ids_0: list of ids (must not contain special tokens) + token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formatted with + special tokens for the model + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + assert already_has_special_tokens and token_ids_1 is None, ( + "You cannot use ``already_has_special_tokens=False`` with this tokenizer. " + "Please use a slow (full python) tokenizer to activate this argument." + "Or set `return_special_token_mask=True` when calling the encoding method " + "to get the special tokens mask in any tokenizer. " + ) + + all_special_ids = self.all_special_ids # cache the property + + special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0] + + return special_tokens_mask + @staticmethod def clean_up_tokenization(out_string: str) -> str: """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms. diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index c35149852f..f85484f96b 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -672,29 +672,6 @@ class TokenizerTesterMixin: filtered_sequence = [x for x in filtered_sequence if x is not None] self.assertEqual(encoded_sequence, filtered_sequence) - def test_special_tokens_mask_already_has_special_tokens(self): - tokenizers = self.get_tokenizers(do_lower_case=False) - for tokenizer in tokenizers: - if not hasattr(tokenizer, "get_special_tokens_mask") or tokenizer.get_special_tokens_mask( - [0, 1, 2, 3] - ) == [0, 0, 0, 0]: - continue - with self.subTest(f"{tokenizer.__class__.__name__}"): - sequence_0 = "Encode this." 
- if ( - tokenizer.cls_token_id == tokenizer.unk_token_id - and tokenizer.cls_token_id == tokenizer.unk_token_id - ): - tokenizer.add_special_tokens({"cls_token": "", "sep_token": ""}) - encoded_sequence_dict = tokenizer.encode_plus( - sequence_0, add_special_tokens=True, return_special_tokens_mask=True - ) - # encoded_sequence_w_special = encoded_sequence_dict["input_ids"] - special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"] - min_val = min(special_tokens_mask_orig) - max_val = max(special_tokens_mask_orig) - self.assertNotEqual(min_val, max_val) - def test_right_and_left_padding(self): tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: