Use commit hash to look in cache instead of calling head (#18534)
* Use commit hash to look in cache instead of calling head * Add tests * Add attr for local configs too * Stupid typos * Fix tests * Update src/transformers/utils/hub.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Address Julien's comments Co-authored-by: Julien Chaumond <julien@huggingface.co>
This commit is contained in:
@@ -27,7 +27,15 @@ from packaging import version
|
||||
|
||||
from . import __version__
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
PushToHubMixin,
|
||||
cached_file,
|
||||
copy_func,
|
||||
extract_commit_hash,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -343,6 +351,8 @@ class PretrainedConfig(PushToHubMixin):
|
||||
|
||||
# Name or path to the pretrained checkpoint
|
||||
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
||||
# Config hash
|
||||
self._commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
# Drop the transformers version info
|
||||
self.transformers_version = kwargs.pop("transformers_version", None)
|
||||
@@ -539,6 +549,8 @@ class PretrainedConfig(PushToHubMixin):
|
||||
original_kwargs = copy.deepcopy(kwargs)
|
||||
# Get config dict associated with the base config file
|
||||
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
if "_commit_hash" in config_dict:
|
||||
original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
||||
|
||||
# That config file may point us toward another config file to use.
|
||||
if "configuration_files" in config_dict:
|
||||
@@ -564,6 +576,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
@@ -599,7 +612,9 @@ class PretrainedConfig(PushToHubMixin):
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
||||
except EnvironmentError:
|
||||
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
||||
# the original exception.
|
||||
@@ -616,6 +631,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
config_dict["_commit_hash"] = commit_hash
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(
|
||||
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
||||
@@ -648,6 +664,9 @@ class PretrainedConfig(PushToHubMixin):
|
||||
# We remove them so they don't appear in `return_unused_kwargs`.
|
||||
kwargs.pop("_from_auto", None)
|
||||
kwargs.pop("_from_pipeline", None)
|
||||
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
|
||||
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
|
||||
kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
||||
|
||||
config = cls(**config_dict)
|
||||
|
||||
@@ -751,6 +770,8 @@ class PretrainedConfig(PushToHubMixin):
|
||||
output["model_type"] = self.__class__.model_type
|
||||
if "_auto_class" in output:
|
||||
del output["_auto_class"]
|
||||
if "_commit_hash" in output:
|
||||
del output["_commit_hash"]
|
||||
|
||||
# Transformers version when serializing the model
|
||||
output["transformers_version"] = __version__
|
||||
|
||||
@@ -595,6 +595,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
_do_init = kwargs.pop("_do_init", True)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
@@ -625,11 +626,15 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
revision=revision,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
_commit_hash=commit_hash,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
model_kwargs = kwargs
|
||||
|
||||
if commit_hash is None:
|
||||
commit_hash = getattr(config, "_commit_hash", None)
|
||||
|
||||
# Add the dtype to model_kwargs
|
||||
model_kwargs["dtype"] = dtype
|
||||
|
||||
@@ -682,6 +687,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||
|
||||
@@ -748,6 +754,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
# init random models
|
||||
|
||||
@@ -2161,6 +2161,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
@@ -2191,11 +2192,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
revision=revision,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
_commit_hash=commit_hash,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
model_kwargs = kwargs
|
||||
|
||||
if commit_hash is None:
|
||||
commit_hash = getattr(config, "_commit_hash", None)
|
||||
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# index of the files.
|
||||
is_sharded = False
|
||||
@@ -2253,6 +2258,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||
|
||||
@@ -2320,6 +2326,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
@@ -1840,6 +1840,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
int8_threshold = kwargs.pop("int8_threshold", 6.0)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
if trust_remote_code is True:
|
||||
logger.warning(
|
||||
@@ -1918,6 +1919,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
else:
|
||||
model_kwargs = kwargs
|
||||
|
||||
if commit_hash is None:
|
||||
commit_hash = getattr(config, "_commit_hash", None)
|
||||
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# index of the files.
|
||||
is_sharded = False
|
||||
@@ -2004,6 +2008,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||
|
||||
@@ -2078,6 +2083,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
# load pt weights early so that we know which dtype to init the model under
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...dynamic_module_utils import get_class_from_dynamic_module
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
|
||||
from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
|
||||
from ..encoder_decoder import EncoderDecoderConfig
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
@@ -389,7 +389,8 @@ def get_tokenizer_config(
|
||||
tokenizer.save_pretrained("tokenizer-test")
|
||||
tokenizer_config = get_tokenizer_config("tokenizer-test")
|
||||
```"""
|
||||
resolved_config_file = get_file_from_repo(
|
||||
commit_hash = kwargs.get("_commit_hash", None)
|
||||
resolved_config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
TOKENIZER_CONFIG_FILE,
|
||||
cache_dir=cache_dir,
|
||||
@@ -399,13 +400,19 @@ def get_tokenizer_config(
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
if resolved_config_file is None:
|
||||
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
||||
return {}
|
||||
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
||||
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
return json.load(reader)
|
||||
result = json.load(reader)
|
||||
result["_commit_hash"] = commit_hash
|
||||
return result
|
||||
|
||||
|
||||
class AutoTokenizer:
|
||||
@@ -532,6 +539,8 @@ class AutoTokenizer:
|
||||
|
||||
# Next, let's try to use the tokenizer_config file to get the tokenizer class.
|
||||
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
|
||||
if "_commit_hash" in tokenizer_config:
|
||||
kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
|
||||
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
|
||||
tokenizer_auto_map = None
|
||||
if "auto_map" in tokenizer_config:
|
||||
|
||||
@@ -557,7 +557,12 @@ def pipeline(
|
||||
# Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
|
||||
# this is to keep BC).
|
||||
use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
|
||||
hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
|
||||
hub_kwargs = {
|
||||
"revision": revision,
|
||||
"use_auth_token": use_auth_token,
|
||||
"trust_remote_code": trust_remote_code,
|
||||
"_commit_hash": None,
|
||||
}
|
||||
|
||||
if task is None and model is None:
|
||||
raise RuntimeError(
|
||||
@@ -583,8 +588,10 @@ def pipeline(
|
||||
# Instantiate config if needed
|
||||
if isinstance(config, str):
|
||||
config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
||||
hub_kwargs["_commit_hash"] = config._commit_hash
|
||||
elif config is None and isinstance(model, str):
|
||||
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
||||
hub_kwargs["_commit_hash"] = config._commit_hash
|
||||
|
||||
custom_tasks = {}
|
||||
if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
|
||||
@@ -639,6 +646,7 @@ def pipeline(
|
||||
)
|
||||
if config is None and isinstance(model, str):
|
||||
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
||||
hub_kwargs["_commit_hash"] = config._commit_hash
|
||||
|
||||
if device_map is not None:
|
||||
if "device_map" in model_kwargs:
|
||||
@@ -672,6 +680,7 @@ def pipeline(
|
||||
)
|
||||
|
||||
model_config = model.config
|
||||
hub_kwargs["_commit_hash"] = model.config._commit_hash
|
||||
|
||||
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
||||
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
||||
|
||||
@@ -31,6 +31,7 @@ from pathlib import Path
|
||||
from typing import Iterator, List, Union
|
||||
from unittest import mock
|
||||
|
||||
import huggingface_hub
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
from .deepspeed import is_deepspeed_available
|
||||
@@ -1588,3 +1589,30 @@ def run_command(command: List[str], return_stdout=False):
|
||||
raise SubprocessCallException(
|
||||
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
||||
) from e
|
||||
|
||||
|
||||
class RequestCounter:
|
||||
"""
|
||||
Helper class that will count all requests made online.
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
self.head_request_count = 0
|
||||
self.get_request_count = 0
|
||||
self.other_request_count = 0
|
||||
self.old_request = huggingface_hub.file_download.requests.request
|
||||
huggingface_hub.file_download.requests.request = self.new_request
|
||||
return self
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
huggingface_hub.file_download.requests.request = self.old_request
|
||||
|
||||
def new_request(self, method, **kwargs):
|
||||
if method == "GET":
|
||||
self.get_request_count += 1
|
||||
elif method == "HEAD":
|
||||
self.head_request_count += 1
|
||||
else:
|
||||
self.other_request_count += 1
|
||||
|
||||
return self.old_request(method=method, **kwargs)
|
||||
|
||||
@@ -42,7 +42,7 @@ from .utils import (
|
||||
add_end_docstrings,
|
||||
cached_file,
|
||||
copy_func,
|
||||
get_file_from_repo,
|
||||
extract_commit_hash,
|
||||
is_flax_available,
|
||||
is_offline_mode,
|
||||
is_tf_available,
|
||||
@@ -1651,6 +1651,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
commit_hash = kwargs.pop("_commit_hash", None)
|
||||
|
||||
user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__}
|
||||
if from_pipeline is not None:
|
||||
@@ -1690,7 +1691,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
if "tokenizer_file" in vocab_files:
|
||||
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
||||
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
||||
resolved_config_file = get_file_from_repo(
|
||||
resolved_config_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
TOKENIZER_CONFIG_FILE,
|
||||
cache_dir=cache_dir,
|
||||
@@ -1701,7 +1702,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
||||
if resolved_config_file is not None:
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
tokenizer_config = json.load(reader)
|
||||
@@ -1730,7 +1736,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
|
||||
|
||||
if len(unresolved_files) > 0:
|
||||
logger.info(
|
||||
@@ -1763,6 +1771,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
_commit_hash=commit_hash,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1776,6 +1785,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
use_auth_token=None,
|
||||
cache_dir=None,
|
||||
local_files_only=False,
|
||||
_commit_hash=None,
|
||||
**kwargs
|
||||
):
|
||||
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
|
||||
@@ -1791,6 +1801,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
_commit_hash=_commit_hash,
|
||||
**(copy.deepcopy(kwargs)),
|
||||
)
|
||||
else:
|
||||
@@ -1823,6 +1834,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
_commit_hash=_commit_hash,
|
||||
)
|
||||
config_tokenizer_class = config.tokenizer_class
|
||||
except (OSError, ValueError, KeyError):
|
||||
|
||||
@@ -63,6 +63,7 @@ from .hub import (
|
||||
cached_file,
|
||||
default_cache_path,
|
||||
define_sagemaker_information,
|
||||
extract_commit_hash,
|
||||
get_cached_models,
|
||||
get_file_from_repo,
|
||||
get_full_repo_name,
|
||||
|
||||
@@ -38,6 +38,7 @@ from huggingface_hub import (
|
||||
whoami,
|
||||
)
|
||||
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
||||
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers.utils.logging import tqdm
|
||||
@@ -200,11 +201,27 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
return ua
|
||||
|
||||
|
||||
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
|
||||
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
|
||||
"""
|
||||
Extracts the commit hash from a resolved filename toward a cache file.
|
||||
"""
|
||||
if resolved_file is None or commit_hash is not None:
|
||||
return commit_hash
|
||||
|
||||
search = re.search(r"snapshots/([^/]+)/", resolved_file)
|
||||
if search is None:
|
||||
return None
|
||||
commit_hash = search.groups()[0]
|
||||
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
|
||||
|
||||
|
||||
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
|
||||
"""
|
||||
Explores the cache to return the latest cached file for a given revision.
|
||||
"""
|
||||
if revision is None:
|
||||
if commit_hash is not None and revision is not None:
|
||||
raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
|
||||
if revision is None and commit_hash is None:
|
||||
revision = "main"
|
||||
|
||||
model_id = repo_id.replace("/", "--")
|
||||
@@ -216,18 +233,19 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
|
||||
if not os.path.isdir(os.path.join(model_cache, subfolder)):
|
||||
return None
|
||||
|
||||
if commit_hash is None:
|
||||
# Resolve refs (for instance to convert main to the associated commit sha)
|
||||
cached_refs = os.listdir(os.path.join(model_cache, "refs"))
|
||||
if revision in cached_refs:
|
||||
with open(os.path.join(model_cache, "refs", revision)) as f:
|
||||
revision = f.read()
|
||||
commit_hash = f.read()
|
||||
|
||||
cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
|
||||
if revision not in cached_shas:
|
||||
if commit_hash not in cached_shas:
|
||||
# No cache for this revision and we won't try to return a random revision
|
||||
return None
|
||||
|
||||
cached_file = os.path.join(model_cache, "snapshots", revision, filename)
|
||||
cached_file = os.path.join(model_cache, "snapshots", commit_hash, filename)
|
||||
return cached_file if os.path.isfile(cached_file) else None
|
||||
|
||||
|
||||
@@ -265,8 +283,9 @@ def cached_file(
|
||||
local_files_only: bool = False,
|
||||
subfolder: str = "",
|
||||
user_agent: Optional[Union[str, Dict[str, str]]] = None,
|
||||
_raise_exceptions_for_missing_entries=True,
|
||||
_raise_exceptions_for_connection_errors=True,
|
||||
_raise_exceptions_for_missing_entries: bool = True,
|
||||
_raise_exceptions_for_connection_errors: bool = True,
|
||||
_commit_hash: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
||||
@@ -318,6 +337,13 @@ def cached_file(
|
||||
# Download a model weight from the Hub and cache it.
|
||||
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
|
||||
```"""
|
||||
# Private arguments
|
||||
# _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
|
||||
# None.
|
||||
# _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
|
||||
# None.
|
||||
# _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
|
||||
# a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
@@ -339,6 +365,13 @@ def cached_file(
|
||||
cache_dir = TRANSFORMERS_CACHE
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
if _commit_hash is not None:
|
||||
# If the file is cached under that commit hash, we return it directly.
|
||||
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
|
||||
if resolved_file is not None:
|
||||
return resolved_file
|
||||
|
||||
user_agent = http_user_agent(user_agent)
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
@@ -803,6 +836,7 @@ def get_checkpoint_shard_files(
|
||||
user_agent=None,
|
||||
revision=None,
|
||||
subfolder="",
|
||||
_commit_hash=None,
|
||||
):
|
||||
"""
|
||||
For a given model:
|
||||
@@ -848,6 +882,7 @@ def get_checkpoint_shard_files(
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_commit_hash=_commit_hash,
|
||||
)
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here.
|
||||
|
||||
@@ -24,6 +24,7 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_UNKNOWN_IDENTIFIER,
|
||||
SMALL_MODEL_IDENTIFIER,
|
||||
RequestCounter,
|
||||
require_scatter,
|
||||
require_torch,
|
||||
slow,
|
||||
@@ -354,3 +355,21 @@ class AutoModelTest(unittest.TestCase):
|
||||
def test_model_from_flax_suggestion(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
||||
def test_cached_model_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the model.
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
# With a sharded checkpoint
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
# There is no pytorch_model.bin so we still get one call for this one.
|
||||
self.assertEqual(counter.head_request_count, 2)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
@@ -21,6 +21,7 @@ from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5C
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_UNKNOWN_IDENTIFIER,
|
||||
SMALL_MODEL_IDENTIFIER,
|
||||
RequestCounter,
|
||||
require_tensorflow_probability,
|
||||
require_tf,
|
||||
slow,
|
||||
@@ -287,3 +288,21 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
def test_model_from_pt_suggestion(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
|
||||
def test_cached_model_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the model.
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
# With a sharded checkpoint
|
||||
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
with RequestCounter() as counter:
|
||||
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
# There is no pytorch_model.bin so we still get one call for this one.
|
||||
self.assertEqual(counter.head_request_count, 2)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
@@ -48,6 +48,7 @@ from transformers.testing_utils import (
|
||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
||||
DUMMY_UNKNOWN_IDENTIFIER,
|
||||
SMALL_MODEL_IDENTIFIER,
|
||||
RequestCounter,
|
||||
require_tokenizers,
|
||||
slow,
|
||||
)
|
||||
@@ -213,6 +214,7 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
def test_get_tokenizer_config(self):
|
||||
# Check we can load the tokenizer config of an online model.
|
||||
config = get_tokenizer_config("bert-base-cased")
|
||||
_ = config.pop("_commit_hash", None)
|
||||
# If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
|
||||
self.assertEqual(config, {"do_lower_case": False})
|
||||
|
||||
@@ -340,3 +342,13 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||
):
|
||||
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
||||
def test_cached_tokenizer_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the tokenizer.
|
||||
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
# We still have one extra call because the model does not have a added_tokens.json file
|
||||
self.assertEqual(counter.head_request_count, 2)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
@@ -49,6 +49,7 @@ from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
RequestCounter,
|
||||
is_pipeline_test,
|
||||
is_staging_test,
|
||||
nested_simplify,
|
||||
@@ -877,6 +878,16 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
[{"label": "LABEL_0", "score": 0.505}],
|
||||
)
|
||||
|
||||
def test_cached_pipeline_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the pipeline.
|
||||
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
# We still have one extra call because the model does not have a added_tokens.json file
|
||||
self.assertEqual(counter.head_request_count, 2)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
@@ -246,7 +246,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
config.push_to_hub("test-config", use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||
for k, v in config.__dict__.items():
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
@@ -258,7 +258,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||
for k, v in config.__dict__.items():
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
@@ -269,7 +269,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||
for k, v in config.__dict__.items():
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
@@ -283,7 +283,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||
for k, v in config.__dict__.items():
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
@@ -323,7 +323,9 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
base_config = PretrainedConfig()
|
||||
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
|
||||
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
|
||||
self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
|
||||
self.assertListEqual(
|
||||
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
|
||||
)
|
||||
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
|
||||
if len(keys_with_defaults) > 0:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user