Allow add_tokens for ESM (#28535)
* Allow non-special tokens to be added * Add test, fix token adding code * Revert changes to id_to_token and token_to_id * Update the ESM tokenizer to be a bit more standardized * Update src/transformers/models/esm/tokenization_esm.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -14,10 +14,9 @@
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for ESM."""
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import AddedToken
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@@ -91,11 +90,10 @@ class EsmTokenizer(PreTrainedTokenizer):
|
||||
def _tokenize(self, text, **kwargs):
|
||||
return text.split()
|
||||
|
||||
def get_vocab_size(self, with_added_tokens=False):
|
||||
return len(self._id_to_token)
|
||||
|
||||
def get_vocab(self):
|
||||
return {token: i for i, token in enumerate(self.all_tokens)}
|
||||
base_vocab = self._token_to_id.copy()
|
||||
base_vocab.update(self.added_tokens_encoder)
|
||||
return base_vocab
|
||||
|
||||
def token_to_id(self, token: str) -> int:
|
||||
return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
|
||||
@@ -156,7 +154,4 @@ class EsmTokenizer(PreTrainedTokenizer):
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.get_vocab_size(with_added_tokens=False)
|
||||
|
||||
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
|
||||
return super()._add_tokens(new_tokens, special_tokens=True)
|
||||
return len(self.all_tokens)
|
||||
|
||||
@@ -87,3 +87,25 @@ class ESMTokenizationTest(unittest.TestCase):
|
||||
self.assertEqual(len(token_2), 1)
|
||||
self.assertEqual(token_1[0], SPECIAL_TOKEN_1)
|
||||
self.assertEqual(token_2[0], SPECIAL_TOKEN_2)
|
||||
|
||||
def test_add_tokens(self):
|
||||
tokenizer = self.tokenizer_class(self.vocab_file)
|
||||
|
||||
vocab_size = len(tokenizer)
|
||||
self.assertEqual(tokenizer.add_tokens(""), 0)
|
||||
self.assertEqual(tokenizer.add_tokens("testoken"), 1)
|
||||
self.assertEqual(tokenizer.add_tokens(["testoken1", "testtoken2"]), 2)
|
||||
self.assertEqual(len(tokenizer), vocab_size + 3)
|
||||
|
||||
self.assertEqual(tokenizer.add_special_tokens({}), 0)
|
||||
self.assertEqual(tokenizer.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
|
||||
self.assertRaises(AssertionError, tokenizer.add_special_tokens, {"additional_special_tokens": "<testtoken1>"})
|
||||
self.assertEqual(tokenizer.add_special_tokens({"additional_special_tokens": ["<testtoken2>"]}), 1)
|
||||
self.assertEqual(
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
|
||||
)
|
||||
self.assertIn("<testtoken3>", tokenizer.special_tokens_map["additional_special_tokens"])
|
||||
self.assertIsInstance(tokenizer.special_tokens_map["additional_special_tokens"], list)
|
||||
self.assertGreaterEqual(len(tokenizer.special_tokens_map["additional_special_tokens"]), 2)
|
||||
|
||||
self.assertEqual(len(tokenizer), vocab_size + 8)
|
||||
|
||||
Reference in New Issue
Block a user