[AutoTokenizer] Allow creation of tokenizers by tokenizer type (#13668)
* up * up
This commit is contained in:
committed by
GitHub
parent
2608944dc2
commit
8e908c8c74
5
tests/fixtures/merges.txt
vendored
Normal file
5
tests/fixtures/merges.txt
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
#version: 0.2
|
||||
Ġ l
|
||||
Ġl o
|
||||
Ġlo w
|
||||
e r
|
||||
1
tests/fixtures/vocab.json
vendored
Normal file
1
tests/fixtures/vocab.json
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"l": 0, "o": 1, "w": 2, "e": 3, "r": 4, "s": 5, "t": 6, "i": 7, "d": 8, "n": 9, "Ġ": 10, "Ġl": 11, "Ġn": 12, "Ġlo": 13, "Ġlow": 14, "er": 15, "Ġlowest": 16, "Ġnewer": 17, "Ġwider": 18, "<unk>": 19, "<|endoftext|>": 20}
|
||||
10
tests/fixtures/vocab.txt
vendored
Normal file
10
tests/fixtures/vocab.txt
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
[PAD]
|
||||
[SEP]
|
||||
[MASK]
|
||||
[CLS]
|
||||
[unused3]
|
||||
[unused4]
|
||||
[unused5]
|
||||
[unused6]
|
||||
[unused7]
|
||||
[unused8]
|
||||
@@ -13,9 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import (
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@@ -78,6 +82,39 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
|
||||
self.assertEqual(tokenizer.vocab_size, 12)
|
||||
|
||||
def test_tokenizer_from_type(self):
    """Check that `tokenizer_type` selects the slow tokenizer class when `use_fast=False`.

    Each case copies only the fixture files into a fresh temporary directory
    and passes the tokenizer type explicitly via `tokenizer_type`.
    """
    # BERT case: only a vocab.txt is placed in the directory.
    with tempfile.TemporaryDirectory() as bert_dir:
        vocab_dst = os.path.join(bert_dir, "vocab.txt")
        shutil.copy("./tests/fixtures/vocab.txt", vocab_dst)

        bert_tokenizer = AutoTokenizer.from_pretrained(bert_dir, tokenizer_type="bert", use_fast=False)
        self.assertIsInstance(bert_tokenizer, BertTokenizer)

    # GPT-2 case: a vocab.json and a merges.txt are placed in the directory.
    gpt2_fixtures = (
        ("./tests/fixtures/vocab.json", "vocab.json"),
        ("./tests/fixtures/merges.txt", "merges.txt"),
    )
    with tempfile.TemporaryDirectory() as gpt2_dir:
        for fixture_src, fixture_name in gpt2_fixtures:
            shutil.copy(fixture_src, os.path.join(gpt2_dir, fixture_name))

        gpt2_tokenizer = AutoTokenizer.from_pretrained(gpt2_dir, tokenizer_type="gpt2", use_fast=False)
        self.assertIsInstance(gpt2_tokenizer, GPT2Tokenizer)
|
||||
|
||||
@require_tokenizers
def test_tokenizer_from_type_fast(self):
    """Check that `tokenizer_type` yields the fast tokenizer class when `use_fast` is not passed.

    Each case copies only the fixture files into a fresh temporary directory
    and passes the tokenizer type explicitly via `tokenizer_type`.
    """
    # BERT case: only a vocab.txt is placed in the directory.
    with tempfile.TemporaryDirectory() as bert_dir:
        vocab_dst = os.path.join(bert_dir, "vocab.txt")
        shutil.copy("./tests/fixtures/vocab.txt", vocab_dst)

        bert_tokenizer = AutoTokenizer.from_pretrained(bert_dir, tokenizer_type="bert")
        self.assertIsInstance(bert_tokenizer, BertTokenizerFast)

    # GPT-2 case: a vocab.json and a merges.txt are placed in the directory.
    gpt2_fixtures = (
        ("./tests/fixtures/vocab.json", "vocab.json"),
        ("./tests/fixtures/merges.txt", "merges.txt"),
    )
    with tempfile.TemporaryDirectory() as gpt2_dir:
        for fixture_src, fixture_name in gpt2_fixtures:
            shutil.copy(fixture_src, os.path.join(gpt2_dir, fixture_name))

        gpt2_tokenizer = AutoTokenizer.from_pretrained(gpt2_dir, tokenizer_type="gpt2")
        self.assertIsInstance(gpt2_tokenizer, GPT2TokenizerFast)
|
||||
|
||||
def test_tokenizer_from_type_incorrect_name(self):
    """Check that the `tokenizer_type` value "xxx" makes `from_pretrained` raise ValueError."""
    bogus_type = "xxx"
    with pytest.raises(ValueError):
        AutoTokenizer.from_pretrained("./", tokenizer_type=bogus_type)
|
||||
|
||||
@require_tokenizers
|
||||
def test_tokenizer_identifier_with_correct_config(self):
|
||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
||||
|
||||
Reference in New Issue
Block a user