From 03ae1f060bbb8cfd8ba691385b35a7ae09adcf33 Mon Sep 17 00:00:00 2001 From: raghavanone <115454562+raghavanone@users.noreply.github.com> Date: Wed, 23 Nov 2022 20:05:44 +0530 Subject: [PATCH] change the way sentinel tokens can retrived (#20373) * change the way sentinel tokens can retrived * Fix line length for doc string * Fix line length for doc string * Add more stronger test for t5 tokenization * Format file changes * Make a stronger test for filtering sentinel tokens * fix file format issues --- src/transformers/models/t5/tokenization_t5.py | 19 ++++++++++----- .../models/t5/tokenization_t5_fast.py | 17 ++++++++++---- tests/models/t5/test_tokenization_t5.py | 23 +++++++++++++++++++ 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py index 5d016ab7d8..44fc58251c 100644 --- a/src/transformers/models/t5/tokenization_t5.py +++ b/src/transformers/models/t5/tokenization_t5.py @@ -79,12 +79,11 @@ class T5Tokenizer(PreTrainedTokenizer): pad_token (`str`, *optional*, defaults to `"<pad>"`): The token used for padding, for example when batching sequences of different lengths. extra_ids (`int`, *optional*, defaults to 100): - Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are - accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are - indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary - like in T5 preprocessing see - [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)). - additional_special_tokens (`List[str]`, *optional*): + Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are + accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. 
These tokens can be + retrieved by calling get_sentinel_tokens method and token ids can be retrieved by calling get_sentinel_token_ids + method + additional_special_tokens (`List[str]`, *optional*): Additional special tokens used by the tokenizer. sp_model_kwargs (`dict`, *optional*): Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for @@ -213,6 +212,14 @@ class T5Tokenizer(PreTrainedTokenizer): return ([0] * len(token_ids_0)) + [1] return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + def get_sentinel_tokens(self): + return list( + set(filter(lambda x: bool(re.search("", x)) is not None, self.additional_special_tokens)) + ) + + def get_sentinel_token_ids(self): + return [self._convert_token_to_id(token) for token in self.get_sentinel_tokens()] + def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]: """Do not add eos again if user already added it.""" if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: diff --git a/src/transformers/models/t5/tokenization_t5_fast.py b/src/transformers/models/t5/tokenization_t5_fast.py index 41ad306b74..6fcb34043d 100644 --- a/src/transformers/models/t5/tokenization_t5_fast.py +++ b/src/transformers/models/t5/tokenization_t5_fast.py @@ -16,6 +16,7 @@ import os +import re import warnings from shutil import copyfile from typing import List, Optional, Tuple @@ -90,11 +91,9 @@ class T5TokenizerFast(PreTrainedTokenizerFast): pad_token (`str`, *optional*, defaults to `"<pad>"`): The token used for padding, for example when batching sequences of different lengths. extra_ids (`int`, *optional*, defaults to 100): - Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are - accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. 
Extra tokens are - indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary - like in T5 preprocessing see - [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)). + Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are accessible as + "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by + calling get_sentinel_tokens method and token ids can be retrieved by calling get_sentinel_token_ids method additional_special_tokens (`List[str]`, *optional*): Additional special tokens used by the tokenizer. """ @@ -235,3 +234,11 @@ class T5TokenizerFast(PreTrainedTokenizerFast): if token_ids_1 is None: return len(token_ids_0 + eos) * [0] return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + def get_sentinel_tokens(self): + return list( + set(filter(lambda x: bool(re.search("", x)) is not None, self.additional_special_tokens)) + ) + + def get_sentinel_token_ids(self): + return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()] diff --git a/tests/models/t5/test_tokenization_t5.py b/tests/models/t5/test_tokenization_t5.py index 28d85c77c9..4a8ffb1ced 100644 --- a/tests/models/t5/test_tokenization_t5.py +++ b/tests/models/t5/test_tokenization_t5.py @@ -14,6 +14,7 @@ # limitations under the License. 
import json import os +import re import tempfile import unittest @@ -379,3 +380,25 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): model_name="t5-base", revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b", ) + + def test_get_sentinel_tokens(self): + tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10) + sentinel_tokens = tokenizer.get_sentinel_tokens() + self.assertEquals(len(sentinel_tokens), 10) + self.assertListEqual(sorted(sentinel_tokens), sorted([f"" for i in range(0, 10)])) + self.assertTrue([re.search("", token) is not None for token in sentinel_tokens]) + + def test_get_sentinel_token_ids(self): + tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10) + self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted([i for i in range(1000, 1010)])) + + def test_get_sentinel_tokens_for_fasttokenizer(self): + tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10) + sentinel_tokens = tokenizer.get_sentinel_tokens() + self.assertEquals(len(sentinel_tokens), 10) + self.assertListEqual(sorted(sentinel_tokens), sorted([f"" for i in range(0, 10)])) + self.assertTrue([re.search("", token) is not None for token in sentinel_tokens]) + + def test_get_sentinel_token_ids_for_fasttokenizer(self): + tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10) + self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted([i for i in range(1000, 1010)]))