From 676643c6d6c36927725560926a8e8f7714666d5b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 23 Dec 2021 14:18:07 -0500 Subject: [PATCH] Better logic for getting tokenizer config in AutoTokenizer (#14906) * Better logic for getting tokenizer config in AutoTokenizer * Remove needless import * Remove debug statement * Address review comments --- src/transformers/file_utils.py | 9 ++++++++- src/transformers/models/auto/tokenization_auto.py | 14 +++++++++++++- tests/test_tokenization_auto.py | 4 +++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index cf5df0c294..1178489949 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -50,6 +50,7 @@ from tqdm.auto import tqdm import requests from filelock import FileLock from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami +from requests.exceptions import HTTPError from transformers.utils.versions import importlib_metadata from . import __version__ @@ -2100,7 +2101,13 @@ def get_list_of_files( token = HfFolder.get_token() else: token = None - return list_repo_files(path_or_repo, revision=revision, token=token) + + try: + return list_repo_files(path_or_repo, revision=revision, token=token) + except HTTPError as e: + raise ValueError( + f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?" + ) from e class cached_property(property): diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c6106da8bb..68ecf76c7b 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -18,11 +18,13 @@ import importlib import json import os from collections import OrderedDict +from pathlib import Path from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from ...configuration_utils import PretrainedConfig from ...file_utils import ( cached_path, + get_list_of_files, hf_bucket_url, is_offline_mode, is_sentencepiece_available, @@ -330,6 +332,16 @@ def get_tokenizer_config( logger.info("Offline mode: forcing local_files_only=True") local_files_only = True + # Will raise a ValueError if `pretrained_model_name_or_path` is not a valid path or model identifier + repo_files = get_list_of_files( + pretrained_model_name_or_path, + revision=revision, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + ) + if TOKENIZER_CONFIG_FILE not in [Path(f).name for f in repo_files]: + return {} + pretrained_model_name_or_path = str(pretrained_model_name_or_path) if os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE) @@ -350,7 +362,7 @@ def get_tokenizer_config( use_auth_token=use_auth_token, ) - except (EnvironmentError, ValueError): + except EnvironmentError: logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") return {} diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py index 1d6df0cbfd..665ab7f4b5 100644 --- a/tests/test_tokenization_auto.py +++ b/tests/test_tokenization_auto.py @@ -149,7 +149,9 @@ class AutoTokenizerTest(unittest.TestCase): @require_tokenizers def test_tokenizer_identifier_non_existent(self): for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]: - with self.assertRaises(EnvironmentError): + with self.assertRaisesRegex( + ValueError, ".*is not a local path or a model identifier on the model Hub. Did you make a typo?" + ): _ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists") def test_parents_and_children_in_mappings(self):