Better logic for getting tokenizer config in AutoTokenizer (#14906)

* Better logic for getting tokenizer config in AutoTokenizer

* Remove needless import

* Remove debug statement

* Address review comments
This commit is contained in:
Sylvain Gugger
2021-12-23 14:18:07 -05:00
committed by GitHub
parent f566c6e3b7
commit 676643c6d6
3 changed files with 24 additions and 3 deletions

View File

@@ -50,6 +50,7 @@ from tqdm.auto import tqdm
import requests
from filelock import FileLock
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
from requests.exceptions import HTTPError
from transformers.utils.versions import importlib_metadata
from . import __version__
@@ -2100,7 +2101,13 @@ def get_list_of_files(
token = HfFolder.get_token()
else:
token = None
return list_repo_files(path_or_repo, revision=revision, token=token)
try:
return list_repo_files(path_or_repo, revision=revision, token=token)
except HTTPError as e:
raise ValueError(
f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?"
) from e
class cached_property(property):

View File

@@ -18,11 +18,13 @@ import importlib
import json
import os
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from ...configuration_utils import PretrainedConfig
from ...file_utils import (
cached_path,
get_list_of_files,
hf_bucket_url,
is_offline_mode,
is_sentencepiece_available,
@@ -330,6 +332,16 @@ def get_tokenizer_config(
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
# Will raise a ValueError if `pretrained_model_name_or_path` is not a valid path or model identifier
repo_files = get_list_of_files(
pretrained_model_name_or_path,
revision=revision,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
if TOKENIZER_CONFIG_FILE not in [Path(f).name for f in repo_files]:
return {}
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
@@ -350,7 +362,7 @@ def get_tokenizer_config(
use_auth_token=use_auth_token,
)
except (EnvironmentError, ValueError):
except EnvironmentError:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
return {}

View File

@@ -149,7 +149,9 @@ class AutoTokenizerTest(unittest.TestCase):
@require_tokenizers
def test_tokenizer_identifier_non_existent(self):
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
with self.assertRaises(EnvironmentError):
with self.assertRaisesRegex(
ValueError, ".*is not a local path or a model identifier on the model Hub. Did you make a typo?"
):
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
def test_parents_and_children_in_mappings(self):