Add more tests on tokenizers serialization - fix bugs (#5056)

* update tests for fast tokenizers + fix small bug in saving/loading

* better tests on serialization

* fixing serialization

* comment cleanup
This commit is contained in:
Thomas Wolf
2020-06-24 21:53:08 +02:00
committed by GitHub
parent 0148c262e7
commit 7ac9110711
4 changed files with 170 additions and 98 deletions

View File

@@ -20,10 +20,10 @@ import re
import shutil
import tempfile
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from tests.utils import require_tf, require_torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
if TYPE_CHECKING:
@@ -93,7 +93,7 @@ class TokenizerTesterMixin:
output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
return output_txt, output_ids
def get_tokenizers(self, fast=True, **kwargs) -> PreTrainedTokenizer:
def get_tokenizers(self, fast=True, **kwargs) -> List[PreTrainedTokenizerBase]:
if fast and self.test_rust_tokenizer:
return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
return [self.get_tokenizer(**kwargs)]
@@ -101,7 +101,7 @@ class TokenizerTesterMixin:
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
raise NotImplementedError
# def get_input_output_texts(self) -> Tuple[str, str]:
@@ -156,28 +156,62 @@ class TokenizerTesterMixin:
def test_save_and_load_tokenizer(self):
# safety check on max_len default value so we are sure the test works
tokenizers = self.get_tokenizers(fast=False)
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
self.assertNotEqual(tokenizer.max_len, 42)
# Now let's start the test
tokenizers = self.get_tokenizers(fast=False, model_max_length=42)
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
sample_text = "He is very happy, UNwant\u00E9d,running"
# Isolate this from the other tests because we save additional tokens/etc
tmpdirname = tempfile.mkdtemp()
sample_text = " He is very happy, UNwant\u00E9d,running"
before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
before_vocab = tokenizer.get_vocab()
tokenizer.save_pretrained(tmpdirname)
tokenizer.save_pretrained(self.tmpdirname)
tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
after_vocab = after_tokenizer.get_vocab()
self.assertListEqual(before_tokens, after_tokens)
self.assertDictEqual(before_vocab, after_vocab)
self.assertEqual(tokenizer.model_max_length, 42)
tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname, model_max_length=43)
shutil.rmtree(tmpdirname)
# Now let's start the test
tokenizers = self.get_tokenizers(model_max_length=42)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
# Isolate this from the other tests because we save additional tokens/etc
tmpdirname = tempfile.mkdtemp()
sample_text = " He is very happy, UNwant\u00E9d,running"
tokenizer.add_tokens(["bim", "bambam"])
additional_special_tokens = tokenizer.additional_special_tokens
additional_special_tokens.append("new_additional_special_token")
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
before_vocab = tokenizer.get_vocab()
tokenizer.save_pretrained(tmpdirname)
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
after_vocab = after_tokenizer.get_vocab()
self.assertListEqual(before_tokens, after_tokens)
self.assertDictEqual(before_vocab, after_vocab)
self.assertIn("bim", after_vocab)
self.assertIn("bambam", after_vocab)
self.assertIn("new_additional_special_token", after_tokenizer.additional_special_tokens)
self.assertEqual(after_tokenizer.model_max_length, 42)
tokenizer = tokenizer.__class__.from_pretrained(tmpdirname, model_max_length=43)
self.assertEqual(tokenizer.model_max_length, 43)
shutil.rmtree(tmpdirname)
def test_pickle_tokenizer(self):
"""Google pickle __getstate__ __setstate__ if you are struggling with this."""
tokenizers = self.get_tokenizers()
@@ -265,7 +299,10 @@ class TokenizerTesterMixin:
all_size = len(tokenizer)
self.assertNotEqual(vocab_size, 0)
self.assertEqual(vocab_size, all_size)
# We usually have added tokens from the start in tests because our vocab fixtures are
# smaller than the original vocabs - let's not assert this
# self.assertEqual(vocab_size, all_size)
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
added_toks = tokenizer.add_tokens(new_toks)