Change the way sentinel tokens can be retrieved (#20373)

* Change the way sentinel tokens can be retrieved

* Fix line length for doc string

* Fix line length for doc string

* Add a stronger test for T5 tokenization

* Format file changes

* Make a stronger test for filtering sentinel tokens

* fix file format issues
This commit is contained in:
raghavanone
2022-11-23 20:05:44 +05:30
committed by GitHub
parent 81d82e4f78
commit 03ae1f060b
3 changed files with 48 additions and 11 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import json
import os
import re
import tempfile
import unittest
@@ -379,3 +380,25 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
model_name="t5-base",
revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b",
)
def test_get_sentinel_tokens(self):
    """Check `T5Tokenizer.get_sentinel_tokens` for a tokenizer built with `extra_ids=10`.

    Verifies that exactly 10 tokens are returned, that they are `<extra_id_0>` through
    `<extra_id_9>`, and that every returned token matches the sentinel pattern.
    """
    tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
    sentinel_tokens = tokenizer.get_sentinel_tokens()
    # `assertEqual` replaces the deprecated `assertEquals` alias (removed in Python 3.12).
    self.assertEqual(len(sentinel_tokens), 10)
    self.assertListEqual(sorted(sentinel_tokens), sorted(f"<extra_id_{i}>" for i in range(10)))
    # `all(...)` is required here: asserting on the list itself always passes, because a
    # non-empty list is truthy even when every element is False.
    self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))
def test_get_sentinel_token_ids(self):
    """`T5Tokenizer.get_sentinel_token_ids` must return ids 1000..1009 when `extra_ids=10`."""
    t5_tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
    actual_ids = sorted(t5_tokenizer.get_sentinel_token_ids())
    # A range is already ascending, so materialising it yields the sorted expectation.
    expected_ids = list(range(1000, 1010))
    self.assertListEqual(actual_ids, expected_ids)
def test_get_sentinel_tokens_for_fasttokenizer(self):
    """Check `T5TokenizerFast.get_sentinel_tokens` for a tokenizer built with `extra_ids=10`.

    Verifies that exactly 10 tokens are returned, that they are `<extra_id_0>` through
    `<extra_id_9>`, and that every returned token matches the sentinel pattern.
    """
    tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
    sentinel_tokens = tokenizer.get_sentinel_tokens()
    # `assertEqual` replaces the deprecated `assertEquals` alias (removed in Python 3.12).
    self.assertEqual(len(sentinel_tokens), 10)
    self.assertListEqual(sorted(sentinel_tokens), sorted(f"<extra_id_{i}>" for i in range(10)))
    # `all(...)` is required here: asserting on the list itself always passes, because a
    # non-empty list is truthy even when every element is False.
    self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))
def test_get_sentinel_token_ids_for_fasttokenizer(self):
    """`T5TokenizerFast.get_sentinel_token_ids` must return ids 1000..1009 when `extra_ids=10`."""
    fast_tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
    actual_ids = sorted(fast_tokenizer.get_sentinel_token_ids())
    # A range is already ascending, so materialising it yields the sorted expectation.
    expected_ids = list(range(1000, 1010))
    self.assertListEqual(actual_ids, expected_ids)