Use new huggingface_hub tools for download models (#18438)
* Draft new cached_file * Initial draft for config and model * Small fixes * Fix first batch of tests * Look in cache when internet is down * Fix last tests * Bad black, not fixing all quality errors * Make diff less * Implement change for TF and Flax models * Add tokenizer and feature extractor * For compatibility with main * Add utils to move the cache and auto-do it at first use. * Quality * Deal with empty commit shas * Deal with empty etag * Address review comments
This commit is contained in:
@@ -25,25 +25,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from requests import HTTPError
|
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .utils import (
|
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
|
||||||
CONFIG_NAME,
|
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
|
||||||
EntryNotFoundError,
|
|
||||||
PushToHubMixin,
|
|
||||||
RepositoryNotFoundError,
|
|
||||||
RevisionNotFoundError,
|
|
||||||
cached_path,
|
|
||||||
copy_func,
|
|
||||||
hf_bucket_url,
|
|
||||||
is_offline_mode,
|
|
||||||
is_remote_url,
|
|
||||||
is_torch_available,
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -591,77 +575,43 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
if from_pipeline is not None:
|
if from_pipeline is not None:
|
||||||
user_agent["using_pipeline"] = from_pipeline
|
user_agent["using_pipeline"] = from_pipeline
|
||||||
|
|
||||||
if is_offline_mode() and not local_files_only:
|
|
||||||
logger.info("Offline mode: forcing local_files_only=True")
|
|
||||||
local_files_only = True
|
|
||||||
|
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
|
|
||||||
pretrained_model_name_or_path
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
):
|
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||||
config_file = pretrained_model_name_or_path
|
# Soecial case when pretrained_model_name_or_path is a local file
|
||||||
|
resolved_config_file = pretrained_model_name_or_path
|
||||||
|
is_local = True
|
||||||
else:
|
else:
|
||||||
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
||||||
|
|
||||||
if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
|
try:
|
||||||
config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
|
# Load from local folder or from cache or download from model Hub and cache
|
||||||
else:
|
resolved_config_file = cached_file(
|
||||||
config_file = hf_bucket_url(
|
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
filename=configuration_file,
|
configuration_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
subfolder=subfolder,
|
||||||
mirror=None,
|
)
|
||||||
|
except EnvironmentError:
|
||||||
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
||||||
|
# the original exception.
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
# For any other exception, we throw a generic error.
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||||
|
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
|
||||||
|
f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
|
||||||
|
f" containing a {configuration_file} file"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
# Load from URL or cache if already cached
|
|
||||||
resolved_config_file = cached_path(
|
|
||||||
config_file,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
|
||||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
|
||||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
|
||||||
"`use_auth_token=True`."
|
|
||||||
)
|
|
||||||
except RevisionNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
|
||||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
|
||||||
"available revisions."
|
|
||||||
)
|
|
||||||
except EntryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
|
|
||||||
)
|
|
||||||
except HTTPError as err:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
|
||||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
|
||||||
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
|
|
||||||
" library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
|
||||||
)
|
|
||||||
except EnvironmentError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
|
||||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
|
||||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
|
||||||
f"containing a {configuration_file} file"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load config dict
|
# Load config dict
|
||||||
@@ -671,10 +621,10 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
||||||
)
|
)
|
||||||
|
|
||||||
if resolved_config_file == config_file:
|
if is_local:
|
||||||
logger.info(f"loading configuration file {config_file}")
|
logger.info(f"loading configuration file {resolved_config_file}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
||||||
|
|
||||||
return config_dict, kwargs
|
return config_dict, kwargs
|
||||||
|
|
||||||
|
|||||||
@@ -24,23 +24,15 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from requests import HTTPError
|
|
||||||
|
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .utils import (
|
from .utils import (
|
||||||
FEATURE_EXTRACTOR_NAME,
|
FEATURE_EXTRACTOR_NAME,
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
|
||||||
EntryNotFoundError,
|
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
RepositoryNotFoundError,
|
|
||||||
RevisionNotFoundError,
|
|
||||||
TensorType,
|
TensorType,
|
||||||
cached_path,
|
cached_file,
|
||||||
copy_func,
|
copy_func,
|
||||||
hf_bucket_url,
|
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -388,64 +380,40 @@ class FeatureExtractionMixin(PushToHubMixin):
|
|||||||
local_files_only = True
|
local_files_only = True
|
||||||
|
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
if os.path.isfile(pretrained_model_name_or_path):
|
||||||
feature_extractor_file = pretrained_model_name_or_path
|
resolved_feature_extractor_file = pretrained_model_name_or_path
|
||||||
|
is_local = True
|
||||||
else:
|
else:
|
||||||
feature_extractor_file = hf_bucket_url(
|
feature_extractor_file = FEATURE_EXTRACTOR_NAME
|
||||||
pretrained_model_name_or_path, filename=FEATURE_EXTRACTOR_NAME, revision=revision, mirror=None
|
try:
|
||||||
)
|
# Load from local folder or from cache or download from model Hub and cache
|
||||||
|
resolved_feature_extractor_file = cached_file(
|
||||||
try:
|
pretrained_model_name_or_path,
|
||||||
# Load from URL or cache if already cached
|
feature_extractor_file,
|
||||||
resolved_feature_extractor_file = cached_path(
|
cache_dir=cache_dir,
|
||||||
feature_extractor_file,
|
force_download=force_download,
|
||||||
cache_dir=cache_dir,
|
proxies=proxies,
|
||||||
force_download=force_download,
|
resume_download=resume_download,
|
||||||
proxies=proxies,
|
local_files_only=local_files_only,
|
||||||
resume_download=resume_download,
|
use_auth_token=use_auth_token,
|
||||||
local_files_only=local_files_only,
|
user_agent=user_agent,
|
||||||
use_auth_token=use_auth_token,
|
revision=revision,
|
||||||
user_agent=user_agent,
|
)
|
||||||
)
|
except EnvironmentError:
|
||||||
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
||||||
except RepositoryNotFoundError:
|
# the original exception.
|
||||||
raise EnvironmentError(
|
raise
|
||||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
except Exception:
|
||||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
# For any other exception, we throw a generic error.
|
||||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
raise EnvironmentError(
|
||||||
"`use_auth_token=True`."
|
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
|
||||||
)
|
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||||
except RevisionNotFoundError:
|
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||||
raise EnvironmentError(
|
f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
)
|
||||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
|
||||||
"available revisions."
|
|
||||||
)
|
|
||||||
except EntryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
|
|
||||||
)
|
|
||||||
except HTTPError as err:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
|
||||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
|
||||||
f" containing a {FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run"
|
|
||||||
" the library in offline mode at"
|
|
||||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
|
||||||
)
|
|
||||||
except EnvironmentError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it "
|
|
||||||
"from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
|
||||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
|
||||||
f"containing a {FEATURE_EXTRACTOR_NAME} file"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load feature_extractor dict
|
# Load feature_extractor dict
|
||||||
@@ -458,12 +426,11 @@ class FeatureExtractionMixin(PushToHubMixin):
|
|||||||
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
|
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
|
||||||
)
|
)
|
||||||
|
|
||||||
if resolved_feature_extractor_file == feature_extractor_file:
|
if is_local:
|
||||||
logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
|
logger.info(f"loading configuration file {resolved_feature_extractor_file}")
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"loading feature extractor configuration file {feature_extractor_file} from cache at"
|
f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
|
||||||
f" {resolved_feature_extractor_file}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return feature_extractor_dict, kwargs
|
return feature_extractor_dict, kwargs
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
|
|||||||
from flax.serialization import from_bytes, to_bytes
|
from flax.serialization import from_bytes, to_bytes
|
||||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
from requests import HTTPError
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
@@ -41,20 +40,14 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
FLAX_WEIGHTS_INDEX_NAME,
|
FLAX_WEIGHTS_INDEX_NAME,
|
||||||
FLAX_WEIGHTS_NAME,
|
FLAX_WEIGHTS_NAME,
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
EntryNotFoundError,
|
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
RepositoryNotFoundError,
|
|
||||||
RevisionNotFoundError,
|
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
cached_path,
|
cached_file,
|
||||||
copy_func,
|
copy_func,
|
||||||
has_file,
|
has_file,
|
||||||
hf_bucket_url,
|
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
@@ -557,6 +550,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||||
identifier allowed by git.
|
identifier allowed by git.
|
||||||
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
|
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||||
|
specify the folder name here.
|
||||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
||||||
@@ -598,6 +594,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
_do_init = kwargs.pop("_do_init", True)
|
_do_init = kwargs.pop("_do_init", True)
|
||||||
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
|
|
||||||
if trust_remote_code is True:
|
if trust_remote_code is True:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -642,6 +639,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
if pretrained_model_name_or_path is not None:
|
if pretrained_model_name_or_path is not None:
|
||||||
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
@@ -665,65 +664,44 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||||
f"{pretrained_model_name_or_path}."
|
f"{pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
|
is_local = True
|
||||||
else:
|
else:
|
||||||
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
|
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
|
||||||
archive_file = hf_bucket_url(
|
try:
|
||||||
pretrained_model_name_or_path,
|
# Load from URL or cache if already cached
|
||||||
filename=filename,
|
cached_file_kwargs = dict(
|
||||||
revision=revision,
|
cache_dir=cache_dir,
|
||||||
)
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
)
|
||||||
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||||
|
|
||||||
# redirect to the cache, if necessary
|
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
|
||||||
|
# result when internet is up, the repo and revision exist, but the file does not.
|
||||||
try:
|
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
|
||||||
resolved_archive_file = cached_path(
|
|
||||||
archive_file,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
|
||||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
|
||||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
|
||||||
"login` and pass `use_auth_token=True`."
|
|
||||||
)
|
|
||||||
except RevisionNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
|
||||||
"this model name. Check the model page at "
|
|
||||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
|
||||||
)
|
|
||||||
except EntryNotFoundError:
|
|
||||||
if filename == FLAX_WEIGHTS_NAME:
|
|
||||||
try:
|
|
||||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
archive_file = hf_bucket_url(
|
resolved_archive_file = cached_file(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||||
filename=FLAX_WEIGHTS_INDEX_NAME,
|
|
||||||
revision=revision,
|
|
||||||
)
|
)
|
||||||
resolved_archive_file = cached_path(
|
if resolved_archive_file is not None:
|
||||||
archive_file,
|
is_sharded = True
|
||||||
cache_dir=cache_dir,
|
if resolved_archive_file is None:
|
||||||
force_download=force_download,
|
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||||
proxies=proxies,
|
# message.
|
||||||
resume_download=resume_download,
|
has_file_kwargs = {
|
||||||
local_files_only=local_files_only,
|
"revision": revision,
|
||||||
use_auth_token=use_auth_token,
|
"proxies": proxies,
|
||||||
user_agent=user_agent,
|
"use_auth_token": use_auth_token,
|
||||||
)
|
}
|
||||||
is_sharded = True
|
|
||||||
except EntryNotFoundError:
|
|
||||||
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
|
|
||||||
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
@@ -735,35 +713,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
else:
|
except EnvironmentError:
|
||||||
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
|
||||||
|
# to the original exception.
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
# For any other exception, we throw a generic error.
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||||
|
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||||
|
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||||
|
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
except HTTPError as err:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
|
||||||
f"{err}"
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
|
||||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
|
||||||
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
|
|
||||||
" internet connection or see how to run the library in offline mode at"
|
|
||||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
|
||||||
)
|
|
||||||
except EnvironmentError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
|
||||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
|
||||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
|
||||||
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
|
||||||
)
|
|
||||||
|
|
||||||
if resolved_archive_file == archive_file:
|
if is_local:
|
||||||
logger.info(f"loading weights file {archive_file}")
|
logger.info(f"loading weights file {archive_file}")
|
||||||
|
resolved_archive_file = archive_file
|
||||||
else:
|
else:
|
||||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
||||||
else:
|
else:
|
||||||
resolved_archive_file = None
|
resolved_archive_file = None
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from tensorflow.python.keras.saving import hdf5_format
|
|||||||
|
|
||||||
from huggingface_hub import Repository, list_repo_files
|
from huggingface_hub import Repository, list_repo_files
|
||||||
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
|
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
|
||||||
from requests import HTTPError
|
|
||||||
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||||
|
|
||||||
from . import DataCollatorWithPadding, DefaultDataCollator
|
from . import DataCollatorWithPadding, DefaultDataCollator
|
||||||
@@ -48,22 +47,16 @@ from .generation_tf_utils import TFGenerationMixin
|
|||||||
from .tf_utils import shape_list
|
from .tf_utils import shape_list
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
|
||||||
TF2_WEIGHTS_INDEX_NAME,
|
TF2_WEIGHTS_INDEX_NAME,
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
WEIGHTS_INDEX_NAME,
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
EntryNotFoundError,
|
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
RepositoryNotFoundError,
|
cached_file,
|
||||||
RevisionNotFoundError,
|
|
||||||
cached_path,
|
|
||||||
find_labels,
|
find_labels,
|
||||||
has_file,
|
has_file,
|
||||||
hf_bucket_url,
|
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
|
||||||
logging,
|
logging,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
working_or_temp_dir,
|
working_or_temp_dir,
|
||||||
@@ -2112,6 +2105,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||||
Please refer to the mirror site for more information.
|
Please refer to the mirror site for more information.
|
||||||
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
|
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||||
|
specify the folder name here.
|
||||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
||||||
@@ -2164,6 +2160,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
|
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
|
||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
|
subfolder = kwargs.pop("subfolder", "")
|
||||||
|
|
||||||
if trust_remote_code is True:
|
if trust_remote_code is True:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -2202,9 +2199,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||||
# index of the files.
|
# index of the files.
|
||||||
is_sharded = False
|
is_sharded = False
|
||||||
sharded_metadata = None
|
|
||||||
# Load model
|
# Load model
|
||||||
if pretrained_model_name_or_path is not None:
|
if pretrained_model_name_or_path is not None:
|
||||||
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint in priority if from_pt
|
# Load from a PyTorch checkpoint in priority if from_pt
|
||||||
@@ -2232,68 +2230,43 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||||
f"{pretrained_model_name_or_path}."
|
f"{pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
|
is_local = True
|
||||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
|
is_local = True
|
||||||
else:
|
else:
|
||||||
|
# set correct filename
|
||||||
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
|
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
|
||||||
archive_file = hf_bucket_url(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
filename=filename,
|
|
||||||
revision=revision,
|
|
||||||
mirror=mirror,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
resolved_archive_file = cached_path(
|
cached_file_kwargs = dict(
|
||||||
archive_file,
|
cache_dir=cache_dir,
|
||||||
cache_dir=cache_dir,
|
force_download=force_download,
|
||||||
force_download=force_download,
|
proxies=proxies,
|
||||||
proxies=proxies,
|
resume_download=resume_download,
|
||||||
resume_download=resume_download,
|
local_files_only=local_files_only,
|
||||||
local_files_only=local_files_only,
|
use_auth_token=use_auth_token,
|
||||||
use_auth_token=use_auth_token,
|
user_agent=user_agent,
|
||||||
user_agent=user_agent,
|
revision=revision,
|
||||||
)
|
subfolder=subfolder,
|
||||||
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
)
|
||||||
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||||
|
|
||||||
except RepositoryNotFoundError:
|
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
|
||||||
raise EnvironmentError(
|
# result when internet is up, the repo and revision exist, but the file does not.
|
||||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
|
||||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
|
||||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
|
||||||
"login` and pass `use_auth_token=True`."
|
|
||||||
)
|
|
||||||
except RevisionNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
|
||||||
"this model name. Check the model page at "
|
|
||||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
|
||||||
)
|
|
||||||
except EntryNotFoundError:
|
|
||||||
if filename == TF2_WEIGHTS_NAME:
|
|
||||||
try:
|
|
||||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
archive_file = hf_bucket_url(
|
resolved_archive_file = cached_file(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||||
filename=TF2_WEIGHTS_INDEX_NAME,
|
|
||||||
revision=revision,
|
|
||||||
mirror=mirror,
|
|
||||||
)
|
)
|
||||||
resolved_archive_file = cached_path(
|
if resolved_archive_file is not None:
|
||||||
archive_file,
|
is_sharded = True
|
||||||
cache_dir=cache_dir,
|
if resolved_archive_file is None:
|
||||||
force_download=force_download,
|
# Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
is_sharded = True
|
|
||||||
except EntryNotFoundError:
|
|
||||||
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
|
||||||
# message.
|
# message.
|
||||||
has_file_kwargs = {
|
has_file_kwargs = {
|
||||||
"revision": revision,
|
"revision": revision,
|
||||||
@@ -2312,42 +2285,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
|
||||||
)
|
|
||||||
except HTTPError as err:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
|
||||||
f"{err}"
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
|
||||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
|
||||||
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your internet"
|
|
||||||
" connection or see how to run the library in offline mode at"
|
|
||||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
|
||||||
)
|
|
||||||
except EnvironmentError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
|
||||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
|
||||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
|
||||||
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
|
||||||
)
|
|
||||||
|
|
||||||
if resolved_archive_file == archive_file:
|
except EnvironmentError:
|
||||||
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
|
||||||
|
# to the original exception.
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
# For any other exception, we throw a generic error.
|
||||||
|
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||||
|
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||||
|
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||||
|
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||||
|
)
|
||||||
|
if is_local:
|
||||||
logger.info(f"loading weights file {archive_file}")
|
logger.info(f"loading weights file {archive_file}")
|
||||||
|
resolved_archive_file = archive_file
|
||||||
else:
|
else:
|
||||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
||||||
else:
|
else:
|
||||||
resolved_archive_file = None
|
resolved_archive_file = None
|
||||||
|
|
||||||
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
||||||
if is_sharded:
|
if is_sharded:
|
||||||
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
||||||
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
|
resolved_archive_file, _ = get_checkpoint_shard_files(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
resolved_archive_file,
|
resolved_archive_file,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from packaging import version
|
|||||||
from torch import Tensor, device, nn
|
from torch import Tensor, device, nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from requests import HTTPError
|
|
||||||
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||||
from transformers.utils.import_utils import is_sagemaker_mp_enabled
|
from transformers.utils.import_utils import is_sagemaker_mp_enabled
|
||||||
|
|
||||||
@@ -51,24 +50,18 @@ from .pytorch_utils import ( # noqa: F401
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
FLAX_WEIGHTS_NAME,
|
FLAX_WEIGHTS_NAME,
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
TF_WEIGHTS_NAME,
|
TF_WEIGHTS_NAME,
|
||||||
WEIGHTS_INDEX_NAME,
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
ContextManagers,
|
ContextManagers,
|
||||||
EntryNotFoundError,
|
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
RepositoryNotFoundError,
|
cached_file,
|
||||||
RevisionNotFoundError,
|
|
||||||
cached_path,
|
|
||||||
copy_func,
|
copy_func,
|
||||||
has_file,
|
has_file,
|
||||||
hf_bucket_url,
|
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
@@ -1868,7 +1861,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
if pretrained_model_name_or_path is not None:
|
if pretrained_model_name_or_path is not None:
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
|
if is_local:
|
||||||
if from_tf and os.path.isfile(
|
if from_tf and os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
|
||||||
):
|
):
|
||||||
@@ -1911,10 +1905,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
|
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
|
||||||
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
|
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||||
pretrained_model_name_or_path
|
|
||||||
):
|
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
|
is_local = True
|
||||||
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
|
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
|
||||||
if not from_tf:
|
if not from_tf:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1922,6 +1915,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"from_tf to True to load from this checkpoint."
|
"from_tf to True to load from this checkpoint."
|
||||||
)
|
)
|
||||||
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
|
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
|
||||||
|
is_local = True
|
||||||
else:
|
else:
|
||||||
# set correct filename
|
# set correct filename
|
||||||
if from_tf:
|
if from_tf:
|
||||||
@@ -1931,63 +1925,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
filename = WEIGHTS_NAME
|
filename = WEIGHTS_NAME
|
||||||
|
|
||||||
archive_file = hf_bucket_url(
|
try:
|
||||||
pretrained_model_name_or_path,
|
# Load from URL or cache if already cached
|
||||||
filename=filename,
|
cached_file_kwargs = dict(
|
||||||
revision=revision,
|
cache_dir=cache_dir,
|
||||||
mirror=mirror,
|
force_download=force_download,
|
||||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
proxies=proxies,
|
||||||
)
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
_raise_exceptions_for_missing_entries=False,
|
||||||
|
)
|
||||||
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||||
|
|
||||||
try:
|
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
|
||||||
# Load from URL or cache if already cached
|
# result when internet is up, the repo and revision exist, but the file does not.
|
||||||
resolved_archive_file = cached_path(
|
if resolved_archive_file is None and filename == WEIGHTS_NAME:
|
||||||
archive_file,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
|
||||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
|
||||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
|
||||||
"login` and pass `use_auth_token=True`."
|
|
||||||
)
|
|
||||||
except RevisionNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
|
||||||
"this model name. Check the model page at "
|
|
||||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
|
||||||
)
|
|
||||||
except EntryNotFoundError:
|
|
||||||
if filename == WEIGHTS_NAME:
|
|
||||||
try:
|
|
||||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
archive_file = hf_bucket_url(
|
resolved_archive_file = cached_file(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||||
filename=WEIGHTS_INDEX_NAME,
|
|
||||||
revision=revision,
|
|
||||||
mirror=mirror,
|
|
||||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
|
||||||
)
|
)
|
||||||
resolved_archive_file = cached_path(
|
if resolved_archive_file is not None:
|
||||||
archive_file,
|
is_sharded = True
|
||||||
cache_dir=cache_dir,
|
if resolved_archive_file is None:
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
is_sharded = True
|
|
||||||
except EntryNotFoundError:
|
|
||||||
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||||
# message.
|
# message.
|
||||||
has_file_kwargs = {
|
has_file_kwargs = {
|
||||||
@@ -2013,42 +1976,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
|
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
|
||||||
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
else:
|
except EnvironmentError:
|
||||||
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
|
||||||
|
# to the original exception.
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
# For any other exception, we throw a generic error.
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||||
|
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
|
||||||
|
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||||
|
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
|
||||||
|
f" {FLAX_WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
except HTTPError as err:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
|
||||||
f"{err}"
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
|
||||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
|
||||||
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
|
|
||||||
f" {FLAX_WEIGHTS_NAME}.\nCheckout your internet connection or see how to run the library in"
|
|
||||||
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
|
||||||
)
|
|
||||||
except EnvironmentError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
|
||||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
|
||||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
|
||||||
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
|
|
||||||
f"{FLAX_WEIGHTS_NAME}."
|
|
||||||
)
|
|
||||||
|
|
||||||
if resolved_archive_file == archive_file:
|
if is_local:
|
||||||
logger.info(f"loading weights file {archive_file}")
|
logger.info(f"loading weights file {archive_file}")
|
||||||
|
resolved_archive_file = archive_file
|
||||||
else:
|
else:
|
||||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
||||||
else:
|
else:
|
||||||
resolved_archive_file = None
|
resolved_archive_file = None
|
||||||
|
|
||||||
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
||||||
if is_sharded:
|
if is_sharded:
|
||||||
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
# rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
||||||
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
|
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
resolved_archive_file,
|
resolved_archive_file,
|
||||||
|
|||||||
@@ -35,21 +35,16 @@ from packaging import version
|
|||||||
from . import __version__
|
from . import __version__
|
||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .utils import (
|
from .utils import (
|
||||||
EntryNotFoundError,
|
|
||||||
ExplicitEnum,
|
ExplicitEnum,
|
||||||
PaddingStrategy,
|
PaddingStrategy,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
RepositoryNotFoundError,
|
|
||||||
RevisionNotFoundError,
|
|
||||||
TensorType,
|
TensorType,
|
||||||
add_end_docstrings,
|
add_end_docstrings,
|
||||||
cached_path,
|
cached_file,
|
||||||
copy_func,
|
copy_func,
|
||||||
get_file_from_repo,
|
get_file_from_repo,
|
||||||
hf_bucket_url,
|
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -1669,7 +1664,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
vocab_files = {}
|
vocab_files = {}
|
||||||
init_configuration = {}
|
init_configuration = {}
|
||||||
|
|
||||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
|
if os.path.isfile(pretrained_model_name_or_path):
|
||||||
if len(cls.vocab_files_names) > 1:
|
if len(cls.vocab_files_names) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
|
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
|
||||||
@@ -1689,9 +1685,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
|
||||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
|
||||||
}
|
}
|
||||||
vocab_files_target = {**cls.vocab_files_names, **additional_files_names}
|
vocab_files = {**cls.vocab_files_names, **additional_files_names}
|
||||||
|
|
||||||
if "tokenizer_file" in vocab_files_target:
|
if "tokenizer_file" in vocab_files:
|
||||||
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
# Try to get the tokenizer config to see if there are versioned tokenizer files.
|
||||||
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
fast_tokenizer_file = FULL_TOKENIZER_FILE
|
||||||
resolved_config_file = get_file_from_repo(
|
resolved_config_file = get_file_from_repo(
|
||||||
@@ -1704,80 +1700,38 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
|
subfolder=subfolder,
|
||||||
)
|
)
|
||||||
if resolved_config_file is not None:
|
if resolved_config_file is not None:
|
||||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||||
tokenizer_config = json.load(reader)
|
tokenizer_config = json.load(reader)
|
||||||
if "fast_tokenizer_files" in tokenizer_config:
|
if "fast_tokenizer_files" in tokenizer_config:
|
||||||
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
|
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
|
||||||
vocab_files_target["tokenizer_file"] = fast_tokenizer_file
|
vocab_files["tokenizer_file"] = fast_tokenizer_file
|
||||||
|
|
||||||
# Look for the tokenizer files
|
|
||||||
for file_id, file_name in vocab_files_target.items():
|
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
|
||||||
if subfolder is not None:
|
|
||||||
full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
|
|
||||||
else:
|
|
||||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
|
|
||||||
if not os.path.exists(full_file_name):
|
|
||||||
logger.info(f"Didn't find file {full_file_name}. We won't load it.")
|
|
||||||
full_file_name = None
|
|
||||||
else:
|
|
||||||
full_file_name = hf_bucket_url(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
filename=file_name,
|
|
||||||
subfolder=subfolder,
|
|
||||||
revision=revision,
|
|
||||||
mirror=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
vocab_files[file_id] = full_file_name
|
|
||||||
|
|
||||||
# Get files from url, cache, or disk depending on the case
|
# Get files from url, cache, or disk depending on the case
|
||||||
resolved_vocab_files = {}
|
resolved_vocab_files = {}
|
||||||
unresolved_files = []
|
unresolved_files = []
|
||||||
for file_id, file_path in vocab_files.items():
|
for file_id, file_path in vocab_files.items():
|
||||||
|
print(file_id, file_path)
|
||||||
if file_path is None:
|
if file_path is None:
|
||||||
resolved_vocab_files[file_id] = None
|
resolved_vocab_files[file_id] = None
|
||||||
else:
|
else:
|
||||||
try:
|
resolved_vocab_files[file_id] = cached_file(
|
||||||
resolved_vocab_files[file_id] = cached_path(
|
pretrained_model_name_or_path,
|
||||||
file_path,
|
file_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
)
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
except FileNotFoundError as error:
|
_raise_exceptions_for_missing_entries=False,
|
||||||
if local_files_only:
|
_raise_exceptions_for_connection_errors=False,
|
||||||
unresolved_files.append(file_id)
|
)
|
||||||
else:
|
|
||||||
raise error
|
|
||||||
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
|
||||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
|
||||||
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
|
||||||
"`huggingface-cli login` and pass `use_auth_token=True`."
|
|
||||||
)
|
|
||||||
except RevisionNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
|
||||||
"for this model name. Check the model page at "
|
|
||||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
|
||||||
)
|
|
||||||
except EntryNotFoundError:
|
|
||||||
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
|
|
||||||
resolved_vocab_files[file_id] = None
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
|
|
||||||
resolved_vocab_files[file_id] = None
|
|
||||||
|
|
||||||
if len(unresolved_files) > 0:
|
if len(unresolved_files) > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -1797,7 +1751,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
if file_id not in resolved_vocab_files:
|
if file_id not in resolved_vocab_files:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if file_path == resolved_vocab_files[file_id]:
|
if is_local:
|
||||||
logger.info(f"loading file {file_path}")
|
logger.info(f"loading file {file_path}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
|
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ from .hub import (
|
|||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
RepositoryNotFoundError,
|
RepositoryNotFoundError,
|
||||||
RevisionNotFoundError,
|
RevisionNotFoundError,
|
||||||
|
cached_file,
|
||||||
cached_path,
|
cached_path,
|
||||||
default_cache_path,
|
default_cache_path,
|
||||||
define_sagemaker_information,
|
define_sagemaker_information,
|
||||||
@@ -76,6 +77,7 @@ from .hub import (
|
|||||||
is_local_clone,
|
is_local_clone,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
|
move_cache,
|
||||||
send_example_telemetry,
|
send_example_telemetry,
|
||||||
url_to_filename,
|
url_to_filename,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ import fnmatch
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tarfile
|
import tarfile
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -34,9 +36,20 @@ from urllib.parse import urlparse
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from zipfile import ZipFile, is_zipfile
|
from zipfile import ZipFile, is_zipfile
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
import requests
|
import requests
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, list_repo_files, whoami
|
from huggingface_hub import (
|
||||||
|
CommitOperationAdd,
|
||||||
|
HfFolder,
|
||||||
|
create_commit,
|
||||||
|
create_repo,
|
||||||
|
hf_hub_download,
|
||||||
|
list_repo_files,
|
||||||
|
whoami,
|
||||||
|
)
|
||||||
|
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
||||||
|
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from requests.models import Response
|
from requests.models import Response
|
||||||
from transformers.utils.logging import tqdm
|
from transformers.utils.logging import tqdm
|
||||||
@@ -385,21 +398,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
|||||||
return ua
|
return ua
|
||||||
|
|
||||||
|
|
||||||
class RepositoryNotFoundError(HTTPError):
|
|
||||||
"""
|
|
||||||
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
|
|
||||||
not have access to.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class EntryNotFoundError(HTTPError):
|
|
||||||
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
|
|
||||||
|
|
||||||
|
|
||||||
class RevisionNotFoundError(HTTPError):
|
|
||||||
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
|
|
||||||
|
|
||||||
|
|
||||||
def _raise_for_status(response: Response):
|
def _raise_for_status(response: Response):
|
||||||
"""
|
"""
|
||||||
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
|
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
|
||||||
@@ -628,6 +626,213 @@ def get_from_cache(
|
|||||||
return cache_path
|
return cache_path
|
||||||
|
|
||||||
|
|
||||||
|
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
|
||||||
|
"""
|
||||||
|
Explores the cache to return the latest cached file for a given revision.
|
||||||
|
"""
|
||||||
|
if revision is None:
|
||||||
|
revision = "main"
|
||||||
|
|
||||||
|
model_id = repo_id.replace("/", "--")
|
||||||
|
model_cache = os.path.join(cache_dir, f"models--{model_id}")
|
||||||
|
if not os.path.isdir(model_cache):
|
||||||
|
# No cache for this model
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Resolve refs (for instance to convert main to the associated commit sha)
|
||||||
|
cached_refs = os.listdir(os.path.join(model_cache, "refs"))
|
||||||
|
if revision in cached_refs:
|
||||||
|
with open(os.path.join(model_cache, "refs", revision)) as f:
|
||||||
|
revision = f.read()
|
||||||
|
|
||||||
|
cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
|
||||||
|
if revision not in cached_shas:
|
||||||
|
# No cache for this revision and we won't try to return a random revision
|
||||||
|
return None
|
||||||
|
|
||||||
|
cached_file = os.path.join(model_cache, "snapshots", revision, filename)
|
||||||
|
return cached_file if os.path.isfile(cached_file) else None
|
||||||
|
|
||||||
|
|
||||||
|
# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to avoid that in the
|
||||||
|
# future.
|
||||||
|
LOCAL_FILES_ONLY_HF_ERROR = (
|
||||||
|
"Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co "
|
||||||
|
"look-ups and downloads online, set 'local_files_only' to False."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# In the future, this ugly contextmanager can be removed when huggingface_hub as a released version where we can
|
||||||
|
# activate/deactivate progress bars.
|
||||||
|
@contextmanager
|
||||||
|
def _patch_hf_hub_tqdm():
|
||||||
|
"""
|
||||||
|
A context manager to make huggingface hub use the tqdm version of Transformers (which is controlled by some utils)
|
||||||
|
in logging.
|
||||||
|
"""
|
||||||
|
old_tqdm = huggingface_hub.file_download.tqdm
|
||||||
|
huggingface_hub.file_download.tqdm = tqdm
|
||||||
|
yield
|
||||||
|
huggingface_hub.file_download.tqdm = old_tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def cached_file(
|
||||||
|
path_or_repo_id: Union[str, os.PathLike],
|
||||||
|
filename: str,
|
||||||
|
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||||
|
force_download: bool = False,
|
||||||
|
resume_download: bool = False,
|
||||||
|
proxies: Optional[Dict[str, str]] = None,
|
||||||
|
use_auth_token: Optional[Union[bool, str]] = None,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
subfolder: str = "",
|
||||||
|
user_agent: Optional[Union[str, Dict[str, str]]] = None,
|
||||||
|
_raise_exceptions_for_missing_entries=True,
|
||||||
|
_raise_exceptions_for_connection_errors=True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_or_repo_id (`str` or `os.PathLike`):
|
||||||
|
This can be either:
|
||||||
|
|
||||||
|
- a string, the *model id* of a model repo on huggingface.co.
|
||||||
|
- a path to a *directory* potentially containing the file.
|
||||||
|
filename (`str`):
|
||||||
|
The name of the file to locate in `path_or_repo`.
|
||||||
|
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||||
|
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
||||||
|
cache should not be used.
|
||||||
|
force_download (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
||||||
|
exist.
|
||||||
|
resume_download (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
|
||||||
|
proxies (`Dict[str, str]`, *optional*):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||||
|
use_auth_token (`str` or *bool*, *optional*):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||||
|
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
|
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||||
|
identifier allowed by git.
|
||||||
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
|
If `True`, will only try to load the tokenizer configuration from local files.
|
||||||
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
|
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||||
|
specify the folder name here.
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
Passing `use_auth_token=True` is required when you want to use a private model.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Download a model weight from the Hub and cache it.
|
||||||
|
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
|
||||||
|
```"""
|
||||||
|
if is_offline_mode() and not local_files_only:
|
||||||
|
logger.info("Offline mode: forcing local_files_only=True")
|
||||||
|
local_files_only = True
|
||||||
|
if subfolder is None:
|
||||||
|
subfolder = ""
|
||||||
|
|
||||||
|
path_or_repo_id = str(path_or_repo_id)
|
||||||
|
full_filename = os.path.join(subfolder, filename)
|
||||||
|
if os.path.isdir(path_or_repo_id):
|
||||||
|
resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
|
||||||
|
if not os.path.isfile(resolved_file):
|
||||||
|
if _raise_exceptions_for_missing_entries:
|
||||||
|
raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
return resolved_file
|
||||||
|
|
||||||
|
if cache_dir is None:
|
||||||
|
cache_dir = TRANSFORMERS_CACHE
|
||||||
|
if isinstance(cache_dir, Path):
|
||||||
|
cache_dir = str(cache_dir)
|
||||||
|
user_agent = http_user_agent(user_agent)
|
||||||
|
try:
|
||||||
|
# Load from URL or cache if already cached
|
||||||
|
with _patch_hf_hub_tqdm():
|
||||||
|
resolved_file = hf_hub_download(
|
||||||
|
path_or_repo_id,
|
||||||
|
filename,
|
||||||
|
subfolder=None if len(subfolder) == 0 else subfolder,
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
user_agent=user_agent,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
except RepositoryNotFoundError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
|
||||||
|
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
||||||
|
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
||||||
|
"`huggingface-cli login` and pass `use_auth_token=True`."
|
||||||
|
)
|
||||||
|
except RevisionNotFoundError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||||
|
"for this model name. Check the model page at "
|
||||||
|
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
|
||||||
|
)
|
||||||
|
except EntryNotFoundError:
|
||||||
|
if not _raise_exceptions_for_missing_entries:
|
||||||
|
return None
|
||||||
|
if revision is None:
|
||||||
|
revision = "main"
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
|
||||||
|
f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
|
||||||
|
)
|
||||||
|
except HTTPError as err:
|
||||||
|
# First we try to see if we have a cached version (not up to date):
|
||||||
|
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
|
||||||
|
if resolved_file is not None:
|
||||||
|
return resolved_file
|
||||||
|
if not _raise_exceptions_for_connection_errors:
|
||||||
|
return None
|
||||||
|
|
||||||
|
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
|
||||||
|
except ValueError as err:
|
||||||
|
# HuggingFace Hub returns a ValueError for a missing file when local_files_only=True we need to catch it here
|
||||||
|
# This could be caught above along in `EntryNotFoundError` if hf_hub sent a different error message here
|
||||||
|
if LOCAL_FILES_ONLY_HF_ERROR in err.args[0] and local_files_only and not _raise_exceptions_for_missing_entries:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Otherwise we try to see if we have a cached version (not up to date):
|
||||||
|
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
|
||||||
|
if resolved_file is not None:
|
||||||
|
return resolved_file
|
||||||
|
if not _raise_exceptions_for_connection_errors:
|
||||||
|
return None
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
|
||||||
|
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
|
||||||
|
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
|
||||||
|
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||||
|
)
|
||||||
|
|
||||||
|
return resolved_file
|
||||||
|
|
||||||
|
|
||||||
def get_file_from_repo(
|
def get_file_from_repo(
|
||||||
path_or_repo: Union[str, os.PathLike],
|
path_or_repo: Union[str, os.PathLike],
|
||||||
filename: str,
|
filename: str,
|
||||||
@@ -638,6 +843,7 @@ def get_file_from_repo(
|
|||||||
use_auth_token: Optional[Union[bool, str]] = None,
|
use_auth_token: Optional[Union[bool, str]] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
|
subfolder: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
|
||||||
@@ -670,6 +876,9 @@ def get_file_from_repo(
|
|||||||
identifier allowed by git.
|
identifier allowed by git.
|
||||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
If `True`, will only try to load the tokenizer configuration from local files.
|
If `True`, will only try to load the tokenizer configuration from local files.
|
||||||
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
|
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
||||||
|
specify the folder name here.
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
@@ -689,47 +898,20 @@ def get_file_from_repo(
|
|||||||
# This model does not have a tokenizer config so the result will be None.
|
# This model does not have a tokenizer config so the result will be None.
|
||||||
tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
|
tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
|
||||||
```"""
|
```"""
|
||||||
if is_offline_mode() and not local_files_only:
|
return cached_file(
|
||||||
logger.info("Offline mode: forcing local_files_only=True")
|
path_or_repo_id=path_or_repo,
|
||||||
local_files_only = True
|
filename=filename,
|
||||||
|
cache_dir=cache_dir,
|
||||||
path_or_repo = str(path_or_repo)
|
force_download=force_download,
|
||||||
if os.path.isdir(path_or_repo):
|
resume_download=resume_download,
|
||||||
resolved_file = os.path.join(path_or_repo, filename)
|
proxies=proxies,
|
||||||
return resolved_file if os.path.isfile(resolved_file) else None
|
use_auth_token=use_auth_token,
|
||||||
else:
|
revision=revision,
|
||||||
resolved_file = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=None)
|
local_files_only=local_files_only,
|
||||||
|
subfolder=subfolder,
|
||||||
try:
|
_raise_exceptions_for_missing_entries=False,
|
||||||
# Load from URL or cache if already cached
|
_raise_exceptions_for_connection_errors=False,
|
||||||
resolved_file = cached_path(
|
)
|
||||||
resolved_file,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
except RepositoryNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{path_or_repo} is not a local folder and is not a valid model identifier "
|
|
||||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
|
||||||
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
|
||||||
"`huggingface-cli login` and pass `use_auth_token=True`."
|
|
||||||
)
|
|
||||||
except RevisionNotFoundError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
|
||||||
"for this model name. Check the model page at "
|
|
||||||
f"'https://huggingface.co/{path_or_repo}' for available revisions."
|
|
||||||
)
|
|
||||||
except EnvironmentError:
|
|
||||||
# The repo and revision exist, but the file does not or there was a connection error fetching it.
|
|
||||||
return None
|
|
||||||
|
|
||||||
return resolved_file
|
|
||||||
|
|
||||||
|
|
||||||
def has_file(
|
def has_file(
|
||||||
@@ -766,7 +948,7 @@ def has_file(
|
|||||||
|
|
||||||
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
|
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
|
||||||
try:
|
try:
|
||||||
_raise_for_status(r)
|
huggingface_hub.utils._errors._raise_for_status(r)
|
||||||
return True
|
return True
|
||||||
except RepositoryNotFoundError as e:
|
except RepositoryNotFoundError as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
@@ -1196,3 +1378,183 @@ def get_checkpoint_shard_files(
|
|||||||
cached_filenames.append(cached_filename)
|
cached_filenames.append(cached_filename)
|
||||||
|
|
||||||
return cached_filenames, sharded_metadata
|
return cached_filenames, sharded_metadata
|
||||||
|
|
||||||
|
|
||||||
|
# All what is below is for conversion between old cache format and new cache format.
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_cached_files(cache_dir=None):
|
||||||
|
"""
|
||||||
|
Returns a list for all files cached with appropriate metadata.
|
||||||
|
"""
|
||||||
|
if cache_dir is None:
|
||||||
|
cache_dir = TRANSFORMERS_CACHE
|
||||||
|
else:
|
||||||
|
cache_dir = str(cache_dir)
|
||||||
|
|
||||||
|
cached_files = []
|
||||||
|
for file in os.listdir(cache_dir):
|
||||||
|
meta_path = os.path.join(cache_dir, f"{file}.json")
|
||||||
|
if not os.path.isfile(meta_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
with open(meta_path, encoding="utf-8") as meta_file:
|
||||||
|
metadata = json.load(meta_file)
|
||||||
|
url = metadata["url"]
|
||||||
|
etag = metadata["etag"].replace('"', "")
|
||||||
|
cached_files.append({"file": file, "url": url, "etag": etag})
|
||||||
|
|
||||||
|
return cached_files
|
||||||
|
|
||||||
|
|
||||||
|
def get_hub_metadata(url, token=None):
|
||||||
|
"""
|
||||||
|
Returns the commit hash and associated etag for a given url.
|
||||||
|
"""
|
||||||
|
if token is None:
|
||||||
|
token = HfFolder.get_token()
|
||||||
|
headers = {"user-agent": http_user_agent()}
|
||||||
|
headers["authorization"] = f"Bearer {token}"
|
||||||
|
|
||||||
|
r = huggingface_hub.file_download._request_with_retry(
|
||||||
|
method="HEAD", url=url, headers=headers, allow_redirects=False
|
||||||
|
)
|
||||||
|
huggingface_hub.file_download._raise_for_status(r)
|
||||||
|
commit_hash = r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
|
||||||
|
etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
|
||||||
|
if etag is not None:
|
||||||
|
etag = huggingface_hub.file_download._normalize_etag(etag)
|
||||||
|
return etag, commit_hash
|
||||||
|
|
||||||
|
|
||||||
|
def extract_info_from_url(url):
|
||||||
|
"""
|
||||||
|
Extract repo_name, revision and filename from an url.
|
||||||
|
"""
|
||||||
|
search = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
|
||||||
|
if search is None:
|
||||||
|
return None
|
||||||
|
repo, revision, filename = search.groups()
|
||||||
|
cache_repo = "--".join(["models"] + repo.split("/"))
|
||||||
|
return {"repo": cache_repo, "revision": revision, "filename": filename}
|
||||||
|
|
||||||
|
|
||||||
|
def clean_files_for(file):
|
||||||
|
"""
|
||||||
|
Remove, if they exist, file, file.json and file.lock
|
||||||
|
"""
|
||||||
|
for f in [file, f"{file}.json", f"{file}.lock"]:
|
||||||
|
if os.path.isfile(f):
|
||||||
|
os.remove(f)
|
||||||
|
|
||||||
|
|
||||||
|
def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
|
||||||
|
"""
|
||||||
|
Move file to repo following the new huggingface hub cache organization.
|
||||||
|
"""
|
||||||
|
os.makedirs(repo, exist_ok=True)
|
||||||
|
|
||||||
|
# refs
|
||||||
|
os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
|
||||||
|
if revision != commit_hash:
|
||||||
|
ref_path = os.path.join(repo, "refs", revision)
|
||||||
|
with open(ref_path, "w") as f:
|
||||||
|
f.write(commit_hash)
|
||||||
|
|
||||||
|
# blobs
|
||||||
|
os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
|
||||||
|
# TODO: replace copy by move when all works well.
|
||||||
|
blob_path = os.path.join(repo, "blobs", etag)
|
||||||
|
shutil.move(file, blob_path)
|
||||||
|
|
||||||
|
# snapshots
|
||||||
|
os.makedirs(os.path.join(repo, "snapshots"), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(repo, "snapshots", commit_hash), exist_ok=True)
|
||||||
|
pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
|
||||||
|
huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
|
||||||
|
clean_files_for(file)
|
||||||
|
|
||||||
|
|
||||||
|
def move_cache(cache_dir=None, token=None):
|
||||||
|
if cache_dir is None:
|
||||||
|
cache_dir = TRANSFORMERS_CACHE
|
||||||
|
if token is None:
|
||||||
|
token = HfFolder.get_token()
|
||||||
|
cached_files = get_all_cached_files(cache_dir=cache_dir)
|
||||||
|
print(f"Moving {len(cached_files)} files to the new cache system")
|
||||||
|
|
||||||
|
hub_metadata = {}
|
||||||
|
for file_info in tqdm(cached_files):
|
||||||
|
url = file_info.pop("url")
|
||||||
|
if url not in hub_metadata:
|
||||||
|
try:
|
||||||
|
hub_metadata[url] = get_hub_metadata(url, token=token)
|
||||||
|
except requests.HTTPError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
etag, commit_hash = hub_metadata[url]
|
||||||
|
if etag is None or commit_hash is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if file_info["etag"] != etag:
|
||||||
|
# Cached file is not up to date, we just throw it as a new version will be downloaded anyway.
|
||||||
|
clean_files_for(os.path.join(cache_dir, file_info["file"]))
|
||||||
|
continue
|
||||||
|
|
||||||
|
url_info = extract_info_from_url(url)
|
||||||
|
if url_info is None:
|
||||||
|
# Not a file from huggingface.co
|
||||||
|
continue
|
||||||
|
|
||||||
|
repo = os.path.join(cache_dir, url_info["repo"])
|
||||||
|
move_to_new_cache(
|
||||||
|
file=os.path.join(cache_dir, file_info["file"]),
|
||||||
|
repo=repo,
|
||||||
|
filename=url_info["filename"],
|
||||||
|
revision=url_info["revision"],
|
||||||
|
etag=etag,
|
||||||
|
commit_hash=commit_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
|
||||||
|
if not os.path.isfile(cache_version_file):
|
||||||
|
cache_version = 0
|
||||||
|
else:
|
||||||
|
with open(cache_version_file) as f:
|
||||||
|
cache_version = int(f.read())
|
||||||
|
|
||||||
|
|
||||||
|
if cache_version < 1:
|
||||||
|
if is_offline_mode():
|
||||||
|
logger.warn(
|
||||||
|
"You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
|
||||||
|
"cache seems to be the one of a previous version. It is very likely that all your calls to any "
|
||||||
|
"`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
|
||||||
|
"your cache be updated automatically, then you can go back to offline mode."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
"The cache for model files in Transformers v4.22.0 has been udpated. Migrating your old cache. This is a "
|
||||||
|
"one-time only operation. You can interrupt this and resume the migration later on by calling "
|
||||||
|
"`transformers.utils.move_cache()`."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
move_cache()
|
||||||
|
except Exception as e:
|
||||||
|
trace = "\n".join(traceback.format_tb(e.__traceback__))
|
||||||
|
logger.error(
|
||||||
|
f"There was a problem when trying to move your cache:\n\n{trace}\n\nPlease file an issue at "
|
||||||
|
"https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole message and we "
|
||||||
|
"will do our best to help."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
|
||||||
|
with open(cache_version_file, "w") as f:
|
||||||
|
f.write("1")
|
||||||
|
except Exception:
|
||||||
|
logger.warn(
|
||||||
|
f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
|
||||||
|
"the environment variable TRANSFORMERS_CACHE to a writable directory."
|
||||||
|
)
|
||||||
|
|||||||
@@ -345,14 +345,14 @@ class ConfigTestUtils(unittest.TestCase):
|
|||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
response_mock.status_code = 500
|
response_mock.status_code = 500
|
||||||
response_mock.headers = []
|
response_mock.headers = {}
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
# Download this model to make sure it's in the cache.
|
||||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
# This check we did call the fake head request
|
# This check we did call the fake head request
|
||||||
mock_head.assert_called()
|
mock_head.assert_called()
|
||||||
|
|||||||
@@ -170,13 +170,13 @@ class FeatureExtractorUtilTester(unittest.TestCase):
|
|||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
response_mock.status_code = 500
|
response_mock.status_code = 500
|
||||||
response_mock.headers = []
|
response_mock.headers = {}
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
# Download this model to make sure it's in the cache.
|
||||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||||
# This check we did call the fake head request
|
# This check we did call the fake head request
|
||||||
mock_head.assert_called()
|
mock_head.assert_called()
|
||||||
|
|||||||
@@ -2925,14 +2925,14 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
response_mock.status_code = 500
|
response_mock.status_code = 500
|
||||||
response_mock.headers = []
|
response_mock.headers = {}
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
# Download this model to make sure it's in the cache.
|
||||||
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||||
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
# This check we did call the fake head request
|
# This check we did call the fake head request
|
||||||
mock_head.assert_called()
|
mock_head.assert_called()
|
||||||
|
|||||||
@@ -1922,14 +1922,14 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
response_mock.status_code = 500
|
response_mock.status_code = 500
|
||||||
response_mock.headers = []
|
response_mock.headers = {}
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
# Download this model to make sure it's in the cache.
|
||||||
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||||
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
# This check we did call the fake head request
|
# This check we did call the fake head request
|
||||||
mock_head.assert_called()
|
mock_head.assert_called()
|
||||||
|
|||||||
@@ -3829,14 +3829,14 @@ class TokenizerUtilTester(unittest.TestCase):
|
|||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
response_mock.status_code = 500
|
response_mock.status_code = 500
|
||||||
response_mock.headers = []
|
response_mock.headers = {}
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
# Download this model to make sure it's in the cache.
|
||||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
|
with mock.patch("requests.request", return_value=response_mock) as mock_head:
|
||||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
# This check we did call the fake head request
|
# This check we did call the fake head request
|
||||||
mock_head.assert_called()
|
mock_head.assert_called()
|
||||||
|
|||||||
Reference in New Issue
Block a user