Use commit hash to look in cache instead of calling head (#18534)
* Use commit hash to look in cache instead of calling head * Add tests * Add attr for local configs too * Stupid typos * Fix tests * Update src/transformers/utils/hub.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Address Julien's comments Co-authored-by: Julien Chaumond <julien@huggingface.co>
This commit is contained in:
@@ -27,7 +27,15 @@ from packaging import version
|
|||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
|
from .utils import (
|
||||||
|
CONFIG_NAME,
|
||||||
|
PushToHubMixin,
|
||||||
|
cached_file,
|
||||||
|
copy_func,
|
||||||
|
extract_commit_hash,
|
||||||
|
is_torch_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -343,6 +351,8 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
|
|
||||||
# Name or path to the pretrained checkpoint
|
# Name or path to the pretrained checkpoint
|
||||||
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
||||||
|
# Config hash
|
||||||
|
self._commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
|
||||||
# Drop the transformers version info
|
# Drop the transformers version info
|
||||||
self.transformers_version = kwargs.pop("transformers_version", None)
|
self.transformers_version = kwargs.pop("transformers_version", None)
|
||||||
@@ -539,6 +549,8 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
original_kwargs = copy.deepcopy(kwargs)
|
original_kwargs = copy.deepcopy(kwargs)
|
||||||
# Get config dict associated with the base config file
|
# Get config dict associated with the base config file
|
||||||
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
|
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||||
|
if "_commit_hash" in config_dict:
|
||||||
|
original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
||||||
|
|
||||||
# That config file may point us toward another config file to use.
|
# That config file may point us toward another config file to use.
|
||||||
if "configuration_files" in config_dict:
|
if "configuration_files" in config_dict:
|
||||||
@@ -564,6 +576,7 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
subfolder = kwargs.pop("subfolder", "")
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
|
||||||
if trust_remote_code is True:
|
if trust_remote_code is True:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -599,7 +612,9 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
||||||
# the original exception.
|
# the original exception.
|
||||||
@@ -616,6 +631,7 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
try:
|
try:
|
||||||
# Load config dict
|
# Load config dict
|
||||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||||
|
config_dict["_commit_hash"] = commit_hash
|
||||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
||||||
@@ -648,6 +664,9 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
# We remove them so they don't appear in `return_unused_kwargs`.
|
# We remove them so they don't appear in `return_unused_kwargs`.
|
||||||
kwargs.pop("_from_auto", None)
|
kwargs.pop("_from_auto", None)
|
||||||
kwargs.pop("_from_pipeline", None)
|
kwargs.pop("_from_pipeline", None)
|
||||||
|
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
|
||||||
|
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
|
||||||
|
kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
||||||
|
|
||||||
config = cls(**config_dict)
|
config = cls(**config_dict)
|
||||||
|
|
||||||
@@ -751,6 +770,8 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
output["model_type"] = self.__class__.model_type
|
output["model_type"] = self.__class__.model_type
|
||||||
if "_auto_class" in output:
|
if "_auto_class" in output:
|
||||||
del output["_auto_class"]
|
del output["_auto_class"]
|
||||||
|
if "_commit_hash" in output:
|
||||||
|
del output["_commit_hash"]
|
||||||
|
|
||||||
# Transformers version when serializing the model
|
# Transformers version when serializing the model
|
||||||
output["transformers_version"] = __version__
|
output["transformers_version"] = __version__
|
||||||
|
|||||||
@@ -595,6 +595,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
_do_init = kwargs.pop("_do_init", True)
|
_do_init = kwargs.pop("_do_init", True)
|
||||||
subfolder = kwargs.pop("subfolder", "")
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
|
||||||
if trust_remote_code is True:
|
if trust_remote_code is True:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -625,11 +626,15 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
_from_auto=from_auto_class,
|
_from_auto=from_auto_class,
|
||||||
_from_pipeline=from_pipeline,
|
_from_pipeline=from_pipeline,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_kwargs = kwargs
|
model_kwargs = kwargs
|
||||||
|
|
||||||
|
if commit_hash is None:
|
||||||
|
commit_hash = getattr(config, "_commit_hash", None)
|
||||||
|
|
||||||
# Add the dtype to model_kwargs
|
# Add the dtype to model_kwargs
|
||||||
model_kwargs["dtype"] = dtype
|
model_kwargs["dtype"] = dtype
|
||||||
|
|
||||||
@@ -682,6 +687,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
_raise_exceptions_for_missing_entries=False,
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||||
|
|
||||||
@@ -748,6 +754,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
# init random models
|
# init random models
|
||||||
|
|||||||
@@ -2161,6 +2161,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
subfolder = kwargs.pop("subfolder", "")
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
|
||||||
if trust_remote_code is True:
|
if trust_remote_code is True:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -2191,11 +2192,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
_from_auto=from_auto_class,
|
_from_auto=from_auto_class,
|
||||||
_from_pipeline=from_pipeline,
|
_from_pipeline=from_pipeline,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_kwargs = kwargs
|
model_kwargs = kwargs
|
||||||
|
|
||||||
|
if commit_hash is None:
|
||||||
|
commit_hash = getattr(config, "_commit_hash", None)
|
||||||
|
|
||||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||||
# index of the files.
|
# index of the files.
|
||||||
is_sharded = False
|
is_sharded = False
|
||||||
@@ -2253,6 +2258,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
_raise_exceptions_for_missing_entries=False,
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||||
|
|
||||||
@@ -2320,6 +2326,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
config.name_or_path = pretrained_model_name_or_path
|
config.name_or_path = pretrained_model_name_or_path
|
||||||
|
|||||||
@@ -1840,6 +1840,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||||
int8_threshold = kwargs.pop("int8_threshold", 6.0)
|
int8_threshold = kwargs.pop("int8_threshold", 6.0)
|
||||||
subfolder = kwargs.pop("subfolder", "")
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
|
||||||
if trust_remote_code is True:
|
if trust_remote_code is True:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -1918,6 +1919,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
model_kwargs = kwargs
|
model_kwargs = kwargs
|
||||||
|
|
||||||
|
if commit_hash is None:
|
||||||
|
commit_hash = getattr(config, "_commit_hash", None)
|
||||||
|
|
||||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||||
# index of the files.
|
# index of the files.
|
||||||
is_sharded = False
|
is_sharded = False
|
||||||
@@ -2004,6 +2008,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
_raise_exceptions_for_missing_entries=False,
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||||
|
|
||||||
@@ -2078,6 +2083,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load pt weights early so that we know which dtype to init the model under
|
# load pt weights early so that we know which dtype to init the model under
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from ...dynamic_module_utils import get_class_from_dynamic_module
|
|||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
|
from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
|
||||||
from ..encoder_decoder import EncoderDecoderConfig
|
from ..encoder_decoder import EncoderDecoderConfig
|
||||||
from .auto_factory import _LazyAutoMapping
|
from .auto_factory import _LazyAutoMapping
|
||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
@@ -389,7 +389,8 @@ def get_tokenizer_config(
|
|||||||
tokenizer.save_pretrained("tokenizer-test")
|
tokenizer.save_pretrained("tokenizer-test")
|
||||||
tokenizer_config = get_tokenizer_config("tokenizer-test")
|
tokenizer_config = get_tokenizer_config("tokenizer-test")
|
||||||
```"""
|
```"""
|
||||||
resolved_config_file = get_file_from_repo(
|
commit_hash = kwargs.get("_commit_hash", None)
|
||||||
|
resolved_config_file = cached_file(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
TOKENIZER_CONFIG_FILE,
|
TOKENIZER_CONFIG_FILE,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
@@ -399,13 +400,19 @@ def get_tokenizer_config(
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
_raise_exceptions_for_connection_errors=False,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
if resolved_config_file is None:
|
if resolved_config_file is None:
|
||||||
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
||||||
return {}
|
return {}
|
||||||
|
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
||||||
|
|
||||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||||
return json.load(reader)
|
result = json.load(reader)
|
||||||
|
result["_commit_hash"] = commit_hash
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class AutoTokenizer:
|
class AutoTokenizer:
|
||||||
@@ -532,6 +539,8 @@ class AutoTokenizer:
|
|||||||
|
|
||||||
# Next, let's try to use the tokenizer_config file to get the tokenizer class.
|
# Next, let's try to use the tokenizer_config file to get the tokenizer class.
|
||||||
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
|
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
|
||||||
|
if "_commit_hash" in tokenizer_config:
|
||||||
|
kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
|
||||||
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
|
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
|
||||||
tokenizer_auto_map = None
|
tokenizer_auto_map = None
|
||||||
if "auto_map" in tokenizer_config:
|
if "auto_map" in tokenizer_config:
|
||||||
|
|||||||
@@ -557,7 +557,12 @@ def pipeline(
|
|||||||
# Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
|
# Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
|
||||||
# this is to keep BC).
|
# this is to keep BC).
|
||||||
use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
|
use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
|
||||||
hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
|
hub_kwargs = {
|
||||||
|
"revision": revision,
|
||||||
|
"use_auth_token": use_auth_token,
|
||||||
|
"trust_remote_code": trust_remote_code,
|
||||||
|
"_commit_hash": None,
|
||||||
|
}
|
||||||
|
|
||||||
if task is None and model is None:
|
if task is None and model is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -583,8 +588,10 @@ def pipeline(
|
|||||||
# Instantiate config if needed
|
# Instantiate config if needed
|
||||||
if isinstance(config, str):
|
if isinstance(config, str):
|
||||||
config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
||||||
|
hub_kwargs["_commit_hash"] = config._commit_hash
|
||||||
elif config is None and isinstance(model, str):
|
elif config is None and isinstance(model, str):
|
||||||
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
||||||
|
hub_kwargs["_commit_hash"] = config._commit_hash
|
||||||
|
|
||||||
custom_tasks = {}
|
custom_tasks = {}
|
||||||
if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
|
if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
|
||||||
@@ -639,6 +646,7 @@ def pipeline(
|
|||||||
)
|
)
|
||||||
if config is None and isinstance(model, str):
|
if config is None and isinstance(model, str):
|
||||||
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
|
||||||
|
hub_kwargs["_commit_hash"] = config._commit_hash
|
||||||
|
|
||||||
if device_map is not None:
|
if device_map is not None:
|
||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
@@ -672,6 +680,7 @@ def pipeline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
model_config = model.config
|
model_config = model.config
|
||||||
|
hub_kwargs["_commit_hash"] = model.config._commit_hash
|
||||||
|
|
||||||
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
||||||
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from pathlib import Path
|
|||||||
from typing import Iterator, List, Union
|
from typing import Iterator, List, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
from .deepspeed import is_deepspeed_available
|
from .deepspeed import is_deepspeed_available
|
||||||
@@ -1588,3 +1589,30 @@ def run_command(command: List[str], return_stdout=False):
|
|||||||
raise SubprocessCallException(
|
raise SubprocessCallException(
|
||||||
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
class RequestCounter:
    """
    Helper class that will count all requests made online.

    Meant to be used as a context manager: on entry the counters are reset to zero and
    `huggingface_hub.file_download.requests.request` is replaced by `new_request`; on exit the function that was
    in place on entry is put back.

    Attributes set on entry:
        head_request_count (`int`): number of `HEAD` requests seen.
        get_request_count (`int`): number of `GET` requests seen.
        other_request_count (`int`): number of requests with any other method.
    """

    def __enter__(self):
        self.head_request_count = 0
        self.get_request_count = 0
        self.other_request_count = 0
        # Keep a handle on the function being replaced so that `__exit__` can restore it.
        self.old_request = huggingface_hub.file_download.requests.request
        huggingface_hub.file_download.requests.request = self.new_request
        return self

    def __exit__(self, *args, **kwargs):
        huggingface_hub.file_download.requests.request = self.old_request

    def new_request(self, method, *args, **kwargs):
        """
        Counts the request by HTTP method, then forwards it unchanged to the original request function.

        Positional arguments after `method` (for instance the URL) are forwarded as well, and the method is compared
        case-insensitively so that `"get"` and `"GET"` land in the same counter.
        """
        # Only the value used for counting is normalized, the caller's `method` is forwarded as received.
        verb = method.upper() if isinstance(method, str) else method
        if verb == "GET":
            self.get_request_count += 1
        elif verb == "HEAD":
            self.head_request_count += 1
        else:
            self.other_request_count += 1

        return self.old_request(method, *args, **kwargs)
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from .utils import (
|
|||||||
add_end_docstrings,
|
add_end_docstrings,
|
||||||
cached_file,
|
cached_file,
|
||||||
copy_func,
|
copy_func,
|
||||||
get_file_from_repo,
|
extract_commit_hash,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
@@ -1651,6 +1651,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
subfolder = kwargs.pop("subfolder", None)
|
subfolder = kwargs.pop("subfolder", None)
|
||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
|
commit_hash = kwargs.pop("_commit_hash", None)
|
||||||
|
|
||||||
user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__}
|
user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__}
|
||||||
if from_pipeline is not None:
|
if from_pipeline is not None:
|
||||||
@@ -1690,7 +1691,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
if "tokenizer_file" in vocab_files:
|
if "tokenizer_file" in vocab_files:
|
||||||
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
||||||
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
||||||
resolved_config_file = get_file_from_repo(
|
resolved_config_file = cached_file(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
TOKENIZER_CONFIG_FILE,
|
TOKENIZER_CONFIG_FILE,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
@@ -1701,7 +1702,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
revision=revision,
|
revision=revision,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
|
user_agent=user_agent,
|
||||||
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
_raise_exceptions_for_connection_errors=False,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
||||||
if resolved_config_file is not None:
|
if resolved_config_file is not None:
|
||||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||||
tokenizer_config = json.load(reader)
|
tokenizer_config = json.load(reader)
|
||||||
@@ -1730,7 +1736,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
_raise_exceptions_for_missing_entries=False,
|
_raise_exceptions_for_missing_entries=False,
|
||||||
_raise_exceptions_for_connection_errors=False,
|
_raise_exceptions_for_connection_errors=False,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
|
||||||
|
|
||||||
if len(unresolved_files) > 0:
|
if len(unresolved_files) > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -1763,6 +1771,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
_commit_hash=commit_hash,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1776,6 +1785,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
use_auth_token=None,
|
use_auth_token=None,
|
||||||
cache_dir=None,
|
cache_dir=None,
|
||||||
local_files_only=False,
|
local_files_only=False,
|
||||||
|
_commit_hash=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
|
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
|
||||||
@@ -1791,6 +1801,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
_commit_hash=_commit_hash,
|
||||||
**(copy.deepcopy(kwargs)),
|
**(copy.deepcopy(kwargs)),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1823,6 +1834,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
_commit_hash=_commit_hash,
|
||||||
)
|
)
|
||||||
config_tokenizer_class = config.tokenizer_class
|
config_tokenizer_class = config.tokenizer_class
|
||||||
except (OSError, ValueError, KeyError):
|
except (OSError, ValueError, KeyError):
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ from .hub import (
|
|||||||
cached_file,
|
cached_file,
|
||||||
default_cache_path,
|
default_cache_path,
|
||||||
define_sagemaker_information,
|
define_sagemaker_information,
|
||||||
|
extract_commit_hash,
|
||||||
get_cached_models,
|
get_cached_models,
|
||||||
get_file_from_repo,
|
get_file_from_repo,
|
||||||
get_full_repo_name,
|
get_full_repo_name,
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from huggingface_hub import (
|
|||||||
whoami,
|
whoami,
|
||||||
)
|
)
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
||||||
|
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
||||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers.utils.logging import tqdm
|
from transformers.utils.logging import tqdm
|
||||||
@@ -200,11 +201,27 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
|||||||
return ua
|
return ua
|
||||||
|
|
||||||
|
|
||||||
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
|
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
    """
    Extracts the commit hash from the path of a resolved file located inside the cache.

    The hash is read from the path component that directly follows a `snapshots` folder, i.e. from a path shaped
    like `.../snapshots/<commit_hash>/<filename>`.

    Args:
        resolved_file (`str`, *optional*):
            Path of the resolved file. When `None`, `commit_hash` is returned unchanged.
        commit_hash (`str`, *optional*):
            An already known commit hash. When not `None`, it is returned as is and the path is not inspected.

    Returns:
        `Optional[str]`: `commit_hash` if it was provided or if `resolved_file` is `None`; otherwise the hash found
        in `resolved_file`, or `None` when the path has no `snapshots/<name>/` component or when that name does not
        match `REGEX_COMMIT_HASH`.
    """
    if resolved_file is None or commit_hash is not None:
        return commit_hash

    # Accept both "/" and "\" as separators so that paths resolved on Windows are handled too.
    search = re.search(r"snapshots[/\\]([^/\\]+)[/\\]", str(resolved_file))
    if search is None:
        return None
    commit_hash = search.group(1)
    # Only return something that actually looks like a commit hash.
    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
|
||||||
|
|
||||||
|
|
||||||
|
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
|
||||||
"""
|
"""
|
||||||
Explores the cache to return the latest cached file for a given revision.
|
Explores the cache to return the latest cached file for a given revision.
|
||||||
"""
|
"""
|
||||||
if revision is None:
|
if commit_hash is not None and revision is not None:
|
||||||
|
raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
|
||||||
|
if revision is None and commit_hash is None:
|
||||||
revision = "main"
|
revision = "main"
|
||||||
|
|
||||||
model_id = repo_id.replace("/", "--")
|
model_id = repo_id.replace("/", "--")
|
||||||
@@ -216,18 +233,19 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
|
|||||||
if not os.path.isdir(os.path.join(model_cache, subfolder)):
|
if not os.path.isdir(os.path.join(model_cache, subfolder)):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Resolve refs (for instance to convert main to the associated commit sha)
|
if commit_hash is None:
|
||||||
cached_refs = os.listdir(os.path.join(model_cache, "refs"))
|
# Resolve refs (for instance to convert main to the associated commit sha)
|
||||||
if revision in cached_refs:
|
cached_refs = os.listdir(os.path.join(model_cache, "refs"))
|
||||||
with open(os.path.join(model_cache, "refs", revision)) as f:
|
if revision in cached_refs:
|
||||||
revision = f.read()
|
with open(os.path.join(model_cache, "refs", revision)) as f:
|
||||||
|
commit_hash = f.read()
|
||||||
|
|
||||||
cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
|
cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
|
||||||
if revision not in cached_shas:
|
if commit_hash not in cached_shas:
|
||||||
# No cache for this revision and we won't try to return a random revision
|
# No cache for this revision and we won't try to return a random revision
|
||||||
return None
|
return None
|
||||||
|
|
||||||
cached_file = os.path.join(model_cache, "snapshots", revision, filename)
|
cached_file = os.path.join(model_cache, "snapshots", commit_hash, filename)
|
||||||
return cached_file if os.path.isfile(cached_file) else None
|
return cached_file if os.path.isfile(cached_file) else None
|
||||||
|
|
||||||
|
|
||||||
@@ -265,8 +283,9 @@ def cached_file(
|
|||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
subfolder: str = "",
|
subfolder: str = "",
|
||||||
user_agent: Optional[Union[str, Dict[str, str]]] = None,
|
user_agent: Optional[Union[str, Dict[str, str]]] = None,
|
||||||
_raise_exceptions_for_missing_entries=True,
|
_raise_exceptions_for_missing_entries: bool = True,
|
||||||
_raise_exceptions_for_connection_errors=True,
|
_raise_exceptions_for_connection_errors: bool = True,
|
||||||
|
_commit_hash: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
||||||
@@ -318,6 +337,13 @@ def cached_file(
|
|||||||
# Download a model weight from the Hub and cache it.
|
# Download a model weight from the Hub and cache it.
|
||||||
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
|
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
|
||||||
```"""
|
```"""
|
||||||
|
# Private arguments
|
||||||
|
# _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
|
||||||
|
# None.
|
||||||
|
# _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
|
||||||
|
# None.
|
||||||
|
# _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
|
||||||
|
# a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
|
||||||
if is_offline_mode() and not local_files_only:
|
if is_offline_mode() and not local_files_only:
|
||||||
logger.info("Offline mode: forcing local_files_only=True")
|
logger.info("Offline mode: forcing local_files_only=True")
|
||||||
local_files_only = True
|
local_files_only = True
|
||||||
@@ -339,6 +365,13 @@ def cached_file(
|
|||||||
cache_dir = TRANSFORMERS_CACHE
|
cache_dir = TRANSFORMERS_CACHE
|
||||||
if isinstance(cache_dir, Path):
|
if isinstance(cache_dir, Path):
|
||||||
cache_dir = str(cache_dir)
|
cache_dir = str(cache_dir)
|
||||||
|
|
||||||
|
if _commit_hash is not None:
|
||||||
|
# If the file is cached under that commit hash, we return it directly.
|
||||||
|
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
|
||||||
|
if resolved_file is not None:
|
||||||
|
return resolved_file
|
||||||
|
|
||||||
user_agent = http_user_agent(user_agent)
|
user_agent = http_user_agent(user_agent)
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
@@ -803,6 +836,7 @@ def get_checkpoint_shard_files(
|
|||||||
user_agent=None,
|
user_agent=None,
|
||||||
revision=None,
|
revision=None,
|
||||||
subfolder="",
|
subfolder="",
|
||||||
|
_commit_hash=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
For a given model:
|
For a given model:
|
||||||
@@ -848,6 +882,7 @@ def get_checkpoint_shard_files(
|
|||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
|
_commit_hash=_commit_hash,
|
||||||
)
|
)
|
||||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||||
# we don't have to catch them here.
|
# we don't have to catch them here.
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
DUMMY_UNKNOWN_IDENTIFIER,
|
DUMMY_UNKNOWN_IDENTIFIER,
|
||||||
SMALL_MODEL_IDENTIFIER,
|
SMALL_MODEL_IDENTIFIER,
|
||||||
|
RequestCounter,
|
||||||
require_scatter,
|
require_scatter,
|
||||||
require_torch,
|
require_torch,
|
||||||
slow,
|
slow,
|
||||||
@@ -354,3 +355,21 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
def test_model_from_flax_suggestion(self):
|
def test_model_from_flax_suggestion(self):
|
||||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
|
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
|
||||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|
||||||
|
def test_cached_model_has_minimum_calls_to_head(self):
|
||||||
|
# Make sure we have cached the model.
|
||||||
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
with RequestCounter() as counter:
|
||||||
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
self.assertEqual(counter.get_request_count, 0)
|
||||||
|
self.assertEqual(counter.head_request_count, 1)
|
||||||
|
self.assertEqual(counter.other_request_count, 0)
|
||||||
|
|
||||||
|
# With a sharded checkpoint
|
||||||
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||||
|
with RequestCounter() as counter:
|
||||||
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||||
|
self.assertEqual(counter.get_request_count, 0)
|
||||||
|
# There is no pytorch_model.bin so we still get one call for this one.
|
||||||
|
self.assertEqual(counter.head_request_count, 2)
|
||||||
|
self.assertEqual(counter.other_request_count, 0)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5C
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
DUMMY_UNKNOWN_IDENTIFIER,
|
DUMMY_UNKNOWN_IDENTIFIER,
|
||||||
SMALL_MODEL_IDENTIFIER,
|
SMALL_MODEL_IDENTIFIER,
|
||||||
|
RequestCounter,
|
||||||
require_tensorflow_probability,
|
require_tensorflow_probability,
|
||||||
require_tf,
|
require_tf,
|
||||||
slow,
|
slow,
|
||||||
@@ -287,3 +288,21 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
def test_model_from_pt_suggestion(self):
|
def test_model_from_pt_suggestion(self):
|
||||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
|
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
|
||||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||||
|
|
||||||
|
def test_cached_model_has_minimum_calls_to_head(self):
|
||||||
|
# Make sure we have cached the model.
|
||||||
|
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
with RequestCounter() as counter:
|
||||||
|
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
self.assertEqual(counter.get_request_count, 0)
|
||||||
|
self.assertEqual(counter.head_request_count, 1)
|
||||||
|
self.assertEqual(counter.other_request_count, 0)
|
||||||
|
|
||||||
|
# With a sharded checkpoint
|
||||||
|
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||||
|
with RequestCounter() as counter:
|
||||||
|
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||||
|
self.assertEqual(counter.get_request_count, 0)
|
||||||
|
# There is no pytorch_model.bin so we still get one call for this one.
|
||||||
|
self.assertEqual(counter.head_request_count, 2)
|
||||||
|
self.assertEqual(counter.other_request_count, 0)
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ from transformers.testing_utils import (
|
|||||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
||||||
DUMMY_UNKNOWN_IDENTIFIER,
|
DUMMY_UNKNOWN_IDENTIFIER,
|
||||||
SMALL_MODEL_IDENTIFIER,
|
SMALL_MODEL_IDENTIFIER,
|
||||||
|
RequestCounter,
|
||||||
require_tokenizers,
|
require_tokenizers,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
@@ -213,6 +214,7 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
def test_get_tokenizer_config(self):
|
def test_get_tokenizer_config(self):
|
||||||
# Check we can load the tokenizer config of an online model.
|
# Check we can load the tokenizer config of an online model.
|
||||||
config = get_tokenizer_config("bert-base-cased")
|
config = get_tokenizer_config("bert-base-cased")
|
||||||
|
_ = config.pop("_commit_hash", None)
|
||||||
# If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
|
# If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
|
||||||
self.assertEqual(config, {"do_lower_case": False})
|
self.assertEqual(config, {"do_lower_case": False})
|
||||||
|
|
||||||
@@ -340,3 +342,13 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||||
):
|
):
|
||||||
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||||
|
|
||||||
|
def test_cached_tokenizer_has_minimum_calls_to_head(self):
|
||||||
|
# Make sure we have cached the tokenizer.
|
||||||
|
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
with RequestCounter() as counter:
|
||||||
|
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
self.assertEqual(counter.get_request_count, 0)
|
||||||
|
# We still have one extra call because the model does not have a added_tokens.json file
|
||||||
|
self.assertEqual(counter.head_request_count, 2)
|
||||||
|
self.assertEqual(counter.other_request_count, 0)
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ from transformers.testing_utils import (
|
|||||||
TOKEN,
|
TOKEN,
|
||||||
USER,
|
USER,
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
|
RequestCounter,
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
@@ -877,6 +878,16 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
[{"label": "LABEL_0", "score": 0.505}],
|
[{"label": "LABEL_0", "score": 0.505}],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_cached_pipeline_has_minimum_calls_to_head(self):
|
||||||
|
# Make sure we have cached the pipeline.
|
||||||
|
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||||
|
with RequestCounter() as counter:
|
||||||
|
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||||
|
self.assertEqual(counter.get_request_count, 0)
|
||||||
|
# We still have one extra call because the model does not have a added_tokens.json file
|
||||||
|
self.assertEqual(counter.head_request_count, 2)
|
||||||
|
self.assertEqual(counter.other_request_count, 0)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
|||||||
config.push_to_hub("test-config", use_auth_token=self._token)
|
config.push_to_hub("test-config", use_auth_token=self._token)
|
||||||
|
|
||||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||||
for k, v in config.__dict__.items():
|
for k, v in config.to_dict().items():
|
||||||
if k != "transformers_version":
|
if k != "transformers_version":
|
||||||
self.assertEqual(v, getattr(new_config, k))
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
@@ -258,7 +258,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
|||||||
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
|
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
|
||||||
|
|
||||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||||
for k, v in config.__dict__.items():
|
for k, v in config.to_dict().items():
|
||||||
if k != "transformers_version":
|
if k != "transformers_version":
|
||||||
self.assertEqual(v, getattr(new_config, k))
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
@@ -269,7 +269,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
|||||||
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
|
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
|
||||||
|
|
||||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||||
for k, v in config.__dict__.items():
|
for k, v in config.to_dict().items():
|
||||||
if k != "transformers_version":
|
if k != "transformers_version":
|
||||||
self.assertEqual(v, getattr(new_config, k))
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
@@ -283,7 +283,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||||
for k, v in config.__dict__.items():
|
for k, v in config.to_dict().items():
|
||||||
if k != "transformers_version":
|
if k != "transformers_version":
|
||||||
self.assertEqual(v, getattr(new_config, k))
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
@@ -323,7 +323,9 @@ class ConfigTestUtils(unittest.TestCase):
|
|||||||
base_config = PretrainedConfig()
|
base_config = PretrainedConfig()
|
||||||
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
|
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
|
||||||
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
|
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
|
||||||
self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
|
self.assertListEqual(
|
||||||
|
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
|
||||||
|
)
|
||||||
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
|
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
|
||||||
if len(keys_with_defaults) > 0:
|
if len(keys_with_defaults) > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
Reference in New Issue
Block a user