change the way sentinel tokens can retrived (#20373)
* change the way sentinel tokens can retrived * Fix line length for doc string * Fix line length for doc string * Add more stronger test for t5 tokenization * Format file changes * Make a stronger test for filtering sentinel tokens * fix file format issues
This commit is contained in:
@@ -79,11 +79,10 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||||
The token used for padding, for example when batching sequences of different lengths.
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
extra_ids (`int`, *optional*, defaults to 100):
|
extra_ids (`int`, *optional*, defaults to 100):
|
||||||
Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
|
Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are
|
||||||
accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
|
accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be
|
||||||
indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
|
retrieved by calling get_sentinel_tokens method and token ids can be by calling get_sentinel_token_ids
|
||||||
like in T5 preprocessing see
|
method
|
||||||
[here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
|
|
||||||
additional_special_tokens (`List[str]`, *optional*):
|
additional_special_tokens (`List[str]`, *optional*):
|
||||||
Additional special tokens used by the tokenizer.
|
Additional special tokens used by the tokenizer.
|
||||||
sp_model_kwargs (`dict`, *optional*):
|
sp_model_kwargs (`dict`, *optional*):
|
||||||
@@ -213,6 +212,14 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
return ([0] * len(token_ids_0)) + [1]
|
return ([0] * len(token_ids_0)) + [1]
|
||||||
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
||||||
|
|
||||||
|
def get_sentinel_tokens(self):
|
||||||
|
return list(
|
||||||
|
set(filter(lambda x: bool(re.search("<extra_id_\d+>", x)) is not None, self.additional_special_tokens))
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_sentinel_token_ids(self):
|
||||||
|
return [self._convert_token_to_id(token) for token in self.get_sentinel_tokens()]
|
||||||
|
|
||||||
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
|
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
|
||||||
"""Do not add eos again if user already added it."""
|
"""Do not add eos again if user already added it."""
|
||||||
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
|
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
@@ -90,11 +91,9 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||||
The token used for padding, for example when batching sequences of different lengths.
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
extra_ids (`int`, *optional*, defaults to 100):
|
extra_ids (`int`, *optional*, defaults to 100):
|
||||||
Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
|
Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are accessible as
|
||||||
accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
|
"<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by
|
||||||
indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
|
calling get_sentinel_tokens method and token ids can be by calling get_sentinel_token_ids method
|
||||||
like in T5 preprocessing see
|
|
||||||
[here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
|
|
||||||
additional_special_tokens (`List[str]`, *optional*):
|
additional_special_tokens (`List[str]`, *optional*):
|
||||||
Additional special tokens used by the tokenizer.
|
Additional special tokens used by the tokenizer.
|
||||||
"""
|
"""
|
||||||
@@ -235,3 +234,11 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
if token_ids_1 is None:
|
if token_ids_1 is None:
|
||||||
return len(token_ids_0 + eos) * [0]
|
return len(token_ids_0 + eos) * [0]
|
||||||
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||||
|
|
||||||
|
def get_sentinel_tokens(self):
|
||||||
|
return list(
|
||||||
|
set(filter(lambda x: bool(re.search("<extra_id_\d+>", x)) is not None, self.additional_special_tokens))
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_sentinel_token_ids(self):
|
||||||
|
return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -379,3 +380,25 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
model_name="t5-base",
|
model_name="t5-base",
|
||||||
revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b",
|
revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_get_sentinel_tokens(self):
|
||||||
|
tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
|
||||||
|
sentinel_tokens = tokenizer.get_sentinel_tokens()
|
||||||
|
self.assertEquals(len(sentinel_tokens), 10)
|
||||||
|
self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
|
||||||
|
self.assertTrue([re.search("<extra_id_\d+>", token) is not None for token in sentinel_tokens])
|
||||||
|
|
||||||
|
def test_get_sentinel_token_ids(self):
|
||||||
|
tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
|
||||||
|
self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted([i for i in range(1000, 1010)]))
|
||||||
|
|
||||||
|
def test_get_sentinel_tokens_for_fasttokenizer(self):
|
||||||
|
tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
|
||||||
|
sentinel_tokens = tokenizer.get_sentinel_tokens()
|
||||||
|
self.assertEquals(len(sentinel_tokens), 10)
|
||||||
|
self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
|
||||||
|
self.assertTrue([re.search("<extra_id_\d+>", token) is not None for token in sentinel_tokens])
|
||||||
|
|
||||||
|
def test_get_sentinel_token_ids_for_fasttokenizer(self):
|
||||||
|
tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
|
||||||
|
self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted([i for i in range(1000, 1010)]))
|
||||||
|
|||||||
Reference in New Issue
Block a user