Better logic for getting tokenizer config in AutoTokenizer (#14906)
* Better logic for getting tokenizer config in AutoTokenizer
* Remove needless import
* Remove debug statement
* Address review comments
This commit is contained in:
@@ -50,6 +50,7 @@ from tqdm.auto import tqdm
|
|||||||
import requests
|
import requests
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
|
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
from transformers.utils.versions import importlib_metadata
|
from transformers.utils.versions import importlib_metadata
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
@@ -2100,7 +2101,13 @@ def get_list_of_files(
|
|||||||
token = HfFolder.get_token()
|
token = HfFolder.get_token()
|
||||||
else:
|
else:
|
||||||
token = None
|
token = None
|
||||||
return list_repo_files(path_or_repo, revision=revision, token=token)
|
|
||||||
|
try:
|
||||||
|
return list_repo_files(path_or_repo, revision=revision, token=token)
|
||||||
|
except HTTPError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
class cached_property(property):
|
class cached_property(property):
|
||||||
|
|||||||
@@ -18,11 +18,13 @@ import importlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
cached_path,
|
cached_path,
|
||||||
|
get_list_of_files,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_sentencepiece_available,
|
is_sentencepiece_available,
|
||||||
@@ -330,6 +332,16 @@ def get_tokenizer_config(
|
|||||||
logger.info("Offline mode: forcing local_files_only=True")
|
logger.info("Offline mode: forcing local_files_only=True")
|
||||||
local_files_only = True
|
local_files_only = True
|
||||||
|
|
||||||
|
# Will raise a ValueError if `pretrained_model_name_or_path` is not a valid path or model identifier
|
||||||
|
repo_files = get_list_of_files(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
revision=revision,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
if TOKENIZER_CONFIG_FILE not in [Path(f).name for f in repo_files]:
|
||||||
|
return {}
|
||||||
|
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
|
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
|
||||||
@@ -350,7 +362,7 @@ def get_tokenizer_config(
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
except (EnvironmentError, ValueError):
|
except EnvironmentError:
|
||||||
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|||||||
@@ -149,7 +149,9 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
def test_tokenizer_identifier_non_existent(self):
|
def test_tokenizer_identifier_non_existent(self):
|
||||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
||||||
with self.assertRaises(EnvironmentError):
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, ".*is not a local path or a model identifier on the model Hub. Did you make a typo?"
|
||||||
|
):
|
||||||
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
|
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
|
||||||
|
|
||||||
def test_parents_and_children_in_mappings(self):
|
def test_parents_and_children_in_mappings(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user