[AutoTokenizer] Allow creation of tokenizers by tokenizer type (#13668)

* up

* up
This commit is contained in:
Patrick von Platen
2021-09-22 00:29:38 +02:00
committed by GitHub
parent 2608944dc2
commit 8e908c8c74
5 changed files with 81 additions and 1 deletions

View File

@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import unittest
import pytest
from transformers import (
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -78,6 +82,39 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
self.assertEqual(tokenizer.vocab_size, 12)
def test_tokenizer_from_type(self):
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))
tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert", use_fast=False)
self.assertIsInstance(tokenizer, BertTokenizer)
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.copy("./tests/fixtures/vocab.json", os.path.join(tmp_dir, "vocab.json"))
shutil.copy("./tests/fixtures/merges.txt", os.path.join(tmp_dir, "merges.txt"))
tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="gpt2", use_fast=False)
self.assertIsInstance(tokenizer, GPT2Tokenizer)
@require_tokenizers
def test_tokenizer_from_type_fast(self):
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.copy("./tests/fixtures/vocab.txt", os.path.join(tmp_dir, "vocab.txt"))
tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="bert")
self.assertIsInstance(tokenizer, BertTokenizerFast)
with tempfile.TemporaryDirectory() as tmp_dir:
shutil.copy("./tests/fixtures/vocab.json", os.path.join(tmp_dir, "vocab.json"))
shutil.copy("./tests/fixtures/merges.txt", os.path.join(tmp_dir, "merges.txt"))
tokenizer = AutoTokenizer.from_pretrained(tmp_dir, tokenizer_type="gpt2")
self.assertIsInstance(tokenizer, GPT2TokenizerFast)
def test_tokenizer_from_type_incorrect_name(self):
with pytest.raises(ValueError):
AutoTokenizer.from_pretrained("./", tokenizer_type="xxx")
@require_tokenizers
def test_tokenizer_identifier_with_correct_config(self):
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]: