* fix #5081 and improve backward compatibility (slightly)
* add nlp to setup.cfg - style and quality
* align default to previous default
* remove test that doesn't generalize
This commit is contained in:
@@ -755,16 +755,6 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
def decode(
|
def decode(
|
||||||
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
|
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
|
||||||
Converts a sequence of ids (integers) into a string, using the tokenizer and vocabulary
|
|
||||||
with options to remove special tokens and clean up tokenization spaces.
|
|
||||||
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
|
|
||||||
skip_special_tokens: if set to True, will remove special tokens.
|
|
||||||
clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
|
|
||||||
"""
|
|
||||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
# To avoid mixing byte-level and unicode for byte-level BPE
|
# To avoid mixing byte-level and unicode for byte-level BPE
|
||||||
|
|||||||
@@ -1774,6 +1774,51 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
|
def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
|
||||||
return [self.decode(seq, **kwargs) for seq in sequences]
|
return [self.decode(seq, **kwargs) for seq in sequences]
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Converts a sequence of ids (integers) into a string, using the tokenizer and vocabulary
|
||||||
|
with options to remove special tokens and clean up tokenization spaces.
|
||||||
|
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
|
||||||
|
skip_special_tokens: if set to True, will remove special tokens.
|
||||||
|
clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_special_tokens_mask(
|
||||||
|
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
|
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0: list of ids (must not contain special tokens)
|
||||||
|
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
|
||||||
|
for sequence pairs
|
||||||
|
already_has_special_tokens: (default False) Set to True if the token list is already formatted with
|
||||||
|
special tokens for the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||||
|
"""
|
||||||
|
assert already_has_special_tokens and token_ids_1 is None, (
|
||||||
|
"You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
|
||||||
|
"Please use a slow (full python) tokenizer to activate this argument."
|
||||||
|
"Or set `return_special_token_mask=True` when calling the encoding method "
|
||||||
|
"to get the special tokens mask in any tokenizer. "
|
||||||
|
)
|
||||||
|
|
||||||
|
all_special_ids = self.all_special_ids # cache the property
|
||||||
|
|
||||||
|
special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
|
||||||
|
|
||||||
|
return special_tokens_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def clean_up_tokenization(out_string: str) -> str:
|
def clean_up_tokenization(out_string: str) -> str:
|
||||||
""" Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
|
""" Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
|
||||||
|
|||||||
@@ -672,29 +672,6 @@ class TokenizerTesterMixin:
|
|||||||
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
filtered_sequence = [x for x in filtered_sequence if x is not None]
|
||||||
self.assertEqual(encoded_sequence, filtered_sequence)
|
self.assertEqual(encoded_sequence, filtered_sequence)
|
||||||
|
|
||||||
def test_special_tokens_mask_already_has_special_tokens(self):
|
|
||||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
|
||||||
for tokenizer in tokenizers:
|
|
||||||
if not hasattr(tokenizer, "get_special_tokens_mask") or tokenizer.get_special_tokens_mask(
|
|
||||||
[0, 1, 2, 3]
|
|
||||||
) == [0, 0, 0, 0]:
|
|
||||||
continue
|
|
||||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
|
||||||
sequence_0 = "Encode this."
|
|
||||||
if (
|
|
||||||
tokenizer.cls_token_id == tokenizer.unk_token_id
|
|
||||||
and tokenizer.cls_token_id == tokenizer.unk_token_id
|
|
||||||
):
|
|
||||||
tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
|
|
||||||
encoded_sequence_dict = tokenizer.encode_plus(
|
|
||||||
sequence_0, add_special_tokens=True, return_special_tokens_mask=True
|
|
||||||
)
|
|
||||||
# encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
|
||||||
special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
|
|
||||||
min_val = min(special_tokens_mask_orig)
|
|
||||||
max_val = max(special_tokens_mask_orig)
|
|
||||||
self.assertNotEqual(min_val, max_val)
|
|
||||||
|
|
||||||
def test_right_and_left_padding(self):
|
def test_right_and_left_padding(self):
|
||||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
for tokenizer in tokenizers:
|
for tokenizer in tokenizers:
|
||||||
|
|||||||
Reference in New Issue
Block a user