Add more tests on tokenizers serialization - fix bugs (#5056)
* update tests for fast tokenizers + fix small bug in saving/loading * better tests on serialization * fixing serialization * comment cleanup
This commit is contained in:
@@ -20,10 +20,10 @@ import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Dict, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
||||
from tests.utils import require_tf, require_torch
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -93,7 +93,7 @@ class TokenizerTesterMixin:
|
||||
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
|
||||
return output_txt, output_ids
|
||||
|
||||
def get_tokenizers(self, fast=True, **kwargs) -> PreTrainedTokenizer:
|
||||
def get_tokenizers(self, fast=True, **kwargs) -> List[PreTrainedTokenizerBase]:
|
||||
if fast and self.test_rust_tokenizer:
|
||||
return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
|
||||
return [self.get_tokenizer(**kwargs)]
|
||||
@@ -101,7 +101,7 @@ class TokenizerTesterMixin:
|
||||
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
|
||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
|
||||
raise NotImplementedError
|
||||
|
||||
# def get_input_output_texts(self) -> Tuple[str, str]:
|
||||
@@ -156,28 +156,62 @@ class TokenizerTesterMixin:
|
||||
|
||||
def test_save_and_load_tokenizer(self):
|
||||
# safety check on max_len default value so we are sure the test works
|
||||
tokenizers = self.get_tokenizers(fast=False)
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
self.assertNotEqual(tokenizer.max_len, 42)
|
||||
|
||||
# Now let's start the test
|
||||
tokenizers = self.get_tokenizers(fast=False, model_max_length=42)
|
||||
tokenizers = self.get_tokenizers()
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
sample_text = "He is very happy, UNwant\u00E9d,running"
|
||||
# Isolate this from the other tests because we save additional tokens/etc
|
||||
tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
sample_text = " He is very happy, UNwant\u00E9d,running"
|
||||
before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
|
||||
before_vocab = tokenizer.get_vocab()
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
|
||||
|
||||
after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
|
||||
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
|
||||
after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
|
||||
after_vocab = after_tokenizer.get_vocab()
|
||||
self.assertListEqual(before_tokens, after_tokens)
|
||||
self.assertDictEqual(before_vocab, after_vocab)
|
||||
|
||||
self.assertEqual(tokenizer.model_max_length, 42)
|
||||
tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname, model_max_length=43)
|
||||
shutil.rmtree(tmpdirname)
|
||||
|
||||
# Now let's start the test
|
||||
tokenizers = self.get_tokenizers(model_max_length=42)
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
# Isolate this from the other tests because we save additional tokens/etc
|
||||
tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
sample_text = " He is very happy, UNwant\u00E9d,running"
|
||||
tokenizer.add_tokens(["bim", "bambam"])
|
||||
additional_special_tokens = tokenizer.additional_special_tokens
|
||||
additional_special_tokens.append("new_additional_special_token")
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
|
||||
before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
|
||||
before_vocab = tokenizer.get_vocab()
|
||||
tokenizer.save_pretrained(tmpdirname)
|
||||
|
||||
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
|
||||
after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
|
||||
after_vocab = after_tokenizer.get_vocab()
|
||||
self.assertListEqual(before_tokens, after_tokens)
|
||||
self.assertDictEqual(before_vocab, after_vocab)
|
||||
self.assertIn("bim", after_vocab)
|
||||
self.assertIn("bambam", after_vocab)
|
||||
self.assertIn("new_additional_special_token", after_tokenizer.additional_special_tokens)
|
||||
self.assertEqual(after_tokenizer.model_max_length, 42)
|
||||
|
||||
tokenizer = tokenizer.__class__.from_pretrained(tmpdirname, model_max_length=43)
|
||||
self.assertEqual(tokenizer.model_max_length, 43)
|
||||
|
||||
shutil.rmtree(tmpdirname)
|
||||
|
||||
def test_pickle_tokenizer(self):
|
||||
"""Google pickle __getstate__ __setstate__ if you are struggling with this."""
|
||||
tokenizers = self.get_tokenizers()
|
||||
@@ -265,7 +299,10 @@ class TokenizerTesterMixin:
|
||||
all_size = len(tokenizer)
|
||||
|
||||
self.assertNotEqual(vocab_size, 0)
|
||||
self.assertEqual(vocab_size, all_size)
|
||||
|
||||
# We usually have added tokens from the start in tests because our vocab fixtures are
|
||||
# smaller than the original vocabs - let's not assert this
|
||||
# self.assertEqual(vocab_size, all_size)
|
||||
|
||||
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
|
||||
added_toks = tokenizer.add_tokens(new_toks)
|
||||
|
||||
Reference in New Issue
Block a user