AutoTokenizer: infer the class from the tokenizer config if possible (#12208)
* AutoTokenizer: infer the class from the tokenizer config if possible
* Add tests
* Update src/transformers/models/auto/tokenization_auto.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
@@ -29,7 +29,7 @@ from transformers import (
|
||||
RobertaTokenizerFast,
|
||||
)
|
||||
from transformers.models.auto.configuration_auto import AutoConfig
|
||||
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
|
||||
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, get_tokenizer_config
|
||||
from transformers.models.roberta.configuration_roberta import RobertaConfig
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
||||
@@ -129,3 +129,34 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(tokenizer.vocab_size, 30000)
|
||||
self.assertEqual(tokenizer.unk_token, "[UNK]")
|
||||
self.assertEqual(tokenizer.padding_side, "right")
|
||||
|
||||
def test_auto_tokenizer_from_local_folder(self):
    # Load a reference tokenizer by identifier and confirm it is one of the Bert tokenizer classes.
    source_tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
    self.assertIsInstance(source_tokenizer, (BertTokenizer, BertTokenizerFast))
    expected_class = source_tokenizer.__class__

    # Save into a scratch directory, then reload using only that directory path.
    with tempfile.TemporaryDirectory() as save_dir:
        source_tokenizer.save_pretrained(save_dir)
        reloaded_tokenizer = AutoTokenizer.from_pretrained(save_dir)

    # The reloaded tokenizer must be an instance of the class that was saved.
    self.assertIsInstance(reloaded_tokenizer, expected_class)
    self.assertEqual(reloaded_tokenizer.vocab_size, 12)
|
||||
|
||||
def test_get_tokenizer_config(self):
    # Case 1: the tokenizer config of an online model can be loaded.
    # NOTE: the expected dict mirrors the current bert-base-cased tokenizer config and must be
    # updated here if that config ever changes.
    online_config = get_tokenizer_config("bert-base-cased")
    self.assertEqual(online_config, {"do_lower_case": False})

    # Case 2: this model has no tokenizer_config, so an empty dict comes back.
    absent_config = get_tokenizer_config(SMALL_MODEL_IDENTIFIER)
    self.assertDictEqual(absent_config, {})

    # Case 3: a tokenizer saved with `save_pretrained` always creates a tokenizer config.
    saved_tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
    with tempfile.TemporaryDirectory() as save_dir:
        saved_tokenizer.save_pretrained(save_dir)
        saved_config = get_tokenizer_config(save_dir)

    # The tokenizer class must have been recorded (note that it always saves the slow class).
    self.assertEqual(saved_config["tokenizer_class"], "BertTokenizer")
    # Check another key to make sure the config was properly saved / reloaded.
    self.assertEqual(saved_config["name_or_path"], SMALL_MODEL_IDENTIFIER)
|
||||
|
||||
Reference in New Issue
Block a user