AutoTokenizer: infer the class from the tokenizer config if possible (#12208)

* AutoTokenizer: infer the class from the tokenizer config if possible

* Add tests

* Update src/transformers/models/auto/tokenization_auto.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-06-17 12:39:22 -04:00
committed by GitHub
parent 0daadc1919
commit adb70eda4d
3 changed files with 168 additions and 10 deletions

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import (
@@ -29,7 +29,7 @@ from transformers import (
RobertaTokenizerFast,
)
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, get_tokenizer_config
from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.testing_utils import (
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
@@ -129,3 +129,34 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertEqual(tokenizer.vocab_size, 30000)
self.assertEqual(tokenizer.unk_token, "[UNK]")
self.assertEqual(tokenizer.padding_side, "right")
def test_auto_tokenizer_from_local_folder(self):
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
tokenizer2 = AutoTokenizer.from_pretrained(tmp_dir)
self.assertIsInstance(tokenizer2, tokenizer.__class__)
self.assertEqual(tokenizer2.vocab_size, 12)
def test_get_tokenizer_config(self):
# Check we can load the tokenizer config of an online model.
config = get_tokenizer_config("bert-base-cased")
# If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
self.assertEqual(config, {"do_lower_case": False})
# This model does not have a tokenizer_config so we get back an empty dict.
config = get_tokenizer_config(SMALL_MODEL_IDENTIFIER)
self.assertDictEqual(config, {})
# A tokenizer saved with `save_pretrained` always creates a tokenizer config.
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
config = get_tokenizer_config(tmp_dir)
# Check the class of the tokenizer was properly saved (note that it always saves the slow class).
self.assertEqual(config["tokenizer_class"], "BertTokenizer")
# Check other keys just to make sure the config was properly saved /reloaded.
self.assertEqual(config["name_or_path"], SMALL_MODEL_IDENTIFIER)