Clean up hub (#18497)
* Clean up utils.hub * Remove imports * More fixes * Last fix
This commit is contained in:
@@ -441,7 +441,6 @@ _import_structure = {
|
||||
"TensorType",
|
||||
"add_end_docstrings",
|
||||
"add_start_docstrings",
|
||||
"cached_path",
|
||||
"is_apex_available",
|
||||
"is_datasets_available",
|
||||
"is_faiss_available",
|
||||
@@ -3214,7 +3213,6 @@ if TYPE_CHECKING:
|
||||
TensorType,
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
cached_path,
|
||||
is_apex_available,
|
||||
is_datasets_available,
|
||||
is_faiss_available,
|
||||
|
||||
@@ -38,7 +38,6 @@ from . import (
|
||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
WEIGHTS_NAME,
|
||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@@ -91,11 +90,10 @@ from . import (
|
||||
XLMConfig,
|
||||
XLMRobertaConfig,
|
||||
XLNetConfig,
|
||||
cached_path,
|
||||
is_torch_available,
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
)
|
||||
from .utils import hf_bucket_url, logging
|
||||
from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -311,7 +309,7 @@ def convert_pt_checkpoint_to_tf(
|
||||
|
||||
# Initialise TF model
|
||||
if config_file in aws_config_map:
|
||||
config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
|
||||
config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
|
||||
config = config_class.from_json_file(config_file)
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
@@ -320,8 +318,9 @@ def convert_pt_checkpoint_to_tf(
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
if pytorch_checkpoint_path in aws_config_map.keys():
|
||||
pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME)
|
||||
pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models)
|
||||
pytorch_checkpoint_path = cached_file(
|
||||
pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
|
||||
)
|
||||
# Load PyTorch checkpoint in tf2 model:
|
||||
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
|
||||
|
||||
@@ -395,14 +394,14 @@ def convert_all_pt_checkpoints_to_tf(
|
||||
print("-" * 100)
|
||||
|
||||
if config_shortcut_name in aws_config_map:
|
||||
config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
|
||||
config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
|
||||
else:
|
||||
config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)
|
||||
config_file = config_shortcut_name
|
||||
|
||||
if model_shortcut_name in aws_model_maps:
|
||||
model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models)
|
||||
model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
|
||||
else:
|
||||
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
|
||||
model_file = model_shortcut_name
|
||||
|
||||
if os.path.isfile(model_shortcut_name):
|
||||
model_shortcut_name = "converted_model"
|
||||
|
||||
@@ -24,14 +24,7 @@ from typing import Dict, Optional, Union
|
||||
|
||||
from huggingface_hub import HfFolder, model_info
|
||||
|
||||
from .utils import (
|
||||
HF_MODULES_CACHE,
|
||||
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
logging,
|
||||
)
|
||||
from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -219,18 +212,15 @@ def get_cached_module_file(
|
||||
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
||||
submodule = "local"
|
||||
else:
|
||||
module_file_or_url = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
|
||||
)
|
||||
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_module_file = cached_path(
|
||||
module_file_or_url,
|
||||
resolved_module_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
module_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
|
||||
@@ -69,20 +69,14 @@ from .utils import (
|
||||
add_end_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
cached_path,
|
||||
cached_property,
|
||||
copy_func,
|
||||
default_cache_path,
|
||||
define_sagemaker_information,
|
||||
filename_to_url,
|
||||
get_cached_models,
|
||||
get_file_from_repo,
|
||||
get_from_cache,
|
||||
get_full_repo_name,
|
||||
get_list_of_files,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
http_get,
|
||||
http_user_agent,
|
||||
is_apex_available,
|
||||
is_coloredlogs_available,
|
||||
@@ -94,7 +88,6 @@ from .utils import (
|
||||
is_in_notebook,
|
||||
is_ipex_available,
|
||||
is_librosa_available,
|
||||
is_local_clone,
|
||||
is_offline_mode,
|
||||
is_onnx_available,
|
||||
is_pandas_available,
|
||||
@@ -105,7 +98,6 @@ from .utils import (
|
||||
is_pyctcdecode_available,
|
||||
is_pytesseract_available,
|
||||
is_pytorch_quantization_available,
|
||||
is_remote_url,
|
||||
is_rjieba_available,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
@@ -141,5 +133,4 @@ from .utils import (
|
||||
torch_only_method,
|
||||
torch_required,
|
||||
torch_version,
|
||||
url_to_filename,
|
||||
)
|
||||
|
||||
@@ -43,15 +43,10 @@ from .models.auto.modeling_auto import (
|
||||
)
|
||||
from .training_args import ParallelMode
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
MODEL_CARD_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
cached_file,
|
||||
is_datasets_available,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_tf_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
@@ -153,11 +148,6 @@ class ModelCard:
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
|
||||
|
||||
find_from_standard_name: (*optional*) boolean, default True:
|
||||
If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them
|
||||
with our standard modelcard filename. Can be used to directly feed a model/config url and access the
|
||||
colocated modelcard.
|
||||
|
||||
return_unused_kwargs: (*optional*) bool:
|
||||
|
||||
- If False, then this function returns just the final model card object.
|
||||
@@ -168,21 +158,15 @@ class ModelCard:
|
||||
Examples:
|
||||
|
||||
```python
|
||||
modelcard = ModelCard.from_pretrained(
|
||||
"bert-base-uncased"
|
||||
) # Download model card from huggingface.co and cache.
|
||||
modelcard = ModelCard.from_pretrained(
|
||||
"./test/saved_model/"
|
||||
) # E.g. model card was saved using *save_pretrained('./test/saved_model/')*
|
||||
# Download model card from huggingface.co and cache.
|
||||
modelcard = ModelCard.from_pretrained("bert-base-uncased")
|
||||
# Model card was saved using *save_pretrained('./test/saved_model/')*
|
||||
modelcard = ModelCard.from_pretrained("./test/saved_model/")
|
||||
modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
|
||||
modelcard = ModelCard.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
|
||||
```"""
|
||||
# This imports every model so let's do it dynamically here.
|
||||
from transformers.models.auto.configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
find_from_standard_name = kwargs.pop("find_from_standard_name", True)
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
|
||||
@@ -190,31 +174,24 @@ class ModelCard:
|
||||
if from_pipeline is not None:
|
||||
user_agent["using_pipeline"] = from_pipeline
|
||||
|
||||
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
||||
# For simplicity we use the same pretrained url as the configuration files
|
||||
# but with a different suffix (modelcard.json). This suffix is replaced below.
|
||||
model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
model_card_file = pretrained_model_name_or_path
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
resolved_model_card_file = pretrained_model_name_or_path
|
||||
is_local = True
|
||||
else:
|
||||
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)
|
||||
|
||||
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
||||
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
|
||||
model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME)
|
||||
model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_model_card_file = cached_path(
|
||||
model_card_file, cache_dir=cache_dir, proxies=proxies, user_agent=user_agent
|
||||
resolved_model_card_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=MODEL_CARD_NAME,
|
||||
cache_dir=cache_dir,
|
||||
proxies=proxies,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if resolved_model_card_file == model_card_file:
|
||||
logger.info(f"loading model card file {model_card_file}")
|
||||
if is_local:
|
||||
logger.info(f"loading model card file {resolved_model_card_file}")
|
||||
else:
|
||||
logger.info(f"loading model card file {model_card_file} from cache at {resolved_model_card_file}")
|
||||
logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
|
||||
# Load model card
|
||||
modelcard = cls.from_json_file(resolved_model_card_file)
|
||||
|
||||
|
||||
@@ -2156,7 +2156,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
@@ -2270,7 +2270,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# message.
|
||||
has_file_kwargs = {
|
||||
"revision": revision,
|
||||
"mirror": mirror,
|
||||
"proxies": proxies,
|
||||
"use_auth_token": use_auth_token,
|
||||
}
|
||||
@@ -2321,7 +2320,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
)
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
@@ -1784,7 +1784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
_fast_init = kwargs.pop("_fast_init", True)
|
||||
@@ -1955,7 +1955,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# message.
|
||||
has_file_kwargs = {
|
||||
"revision": revision,
|
||||
"mirror": mirror,
|
||||
"proxies": proxies,
|
||||
"use_auth_token": use_auth_token,
|
||||
}
|
||||
@@ -2012,7 +2011,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import numpy as np
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url, logging, requires_backends
|
||||
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
|
||||
from .configuration_rag import RagConfig
|
||||
from .tokenization_rag import RagTokenizer
|
||||
|
||||
@@ -111,22 +111,21 @@ class LegacyIndex(Index):
|
||||
self._index_initialized = False
|
||||
|
||||
def _resolve_path(self, index_path, filename):
|
||||
assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid `index_path`."
|
||||
archive_file = os.path.join(index_path, filename)
|
||||
is_local = os.path.isdir(index_path)
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(archive_file)
|
||||
resolved_archive_file = cached_file(index_path, filename)
|
||||
except EnvironmentError:
|
||||
msg = (
|
||||
f"Can't load '{archive_file}'. Make sure that:\n\n"
|
||||
f"Can't load '{filename}'. Make sure that:\n\n"
|
||||
f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
|
||||
f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info(f"loading file {archive_file}")
|
||||
if is_local:
|
||||
logger.info(f"loading file {resolved_archive_file}")
|
||||
else:
|
||||
logger.info(f"loading file {archive_file} from cache at {resolved_archive_file}")
|
||||
logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
|
||||
return resolved_archive_file
|
||||
|
||||
def _load_passages(self):
|
||||
|
||||
@@ -29,7 +29,7 @@ import numpy as np
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import (
|
||||
cached_path,
|
||||
cached_file,
|
||||
is_sacremoses_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
@@ -681,24 +681,21 @@ class TransfoXLCorpus(object):
|
||||
Instantiate a pre-processed corpus.
|
||||
"""
|
||||
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
|
||||
corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
|
||||
resolved_corpus_file = cached_file(pretrained_model_name_or_path, CORPUS_NAME, cache_dir=cache_dir)
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list"
|
||||
f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'"
|
||||
f" was a path or url but couldn't find files {corpus_file} at this path or url."
|
||||
f" was a path or url but couldn't find files {CORPUS_NAME} at this path or url."
|
||||
)
|
||||
return None
|
||||
if resolved_corpus_file == corpus_file:
|
||||
logger.info(f"loading corpus file {corpus_file}")
|
||||
if is_local:
|
||||
logger.info(f"loading corpus file {resolved_corpus_file}")
|
||||
else:
|
||||
logger.info(f"loading corpus file {corpus_file} from cache at {resolved_corpus_file}")
|
||||
logger.info(f"loading corpus file {CORPUS_NAME} from cache at {resolved_corpus_file}")
|
||||
|
||||
# Instantiate tokenizer.
|
||||
corpus = cls(*inputs, **kwargs)
|
||||
|
||||
@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from numpy import isin
|
||||
|
||||
from huggingface_hub.file_download import http_get
|
||||
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..dynamic_module_utils import get_class_from_dynamic_module
|
||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||
@@ -33,7 +35,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, Aut
|
||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from ..tokenization_utils import PreTrainedTokenizer
|
||||
from ..tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, http_get, is_tf_available, is_torch_available, logging
|
||||
from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, is_tf_available, is_torch_available, logging
|
||||
from .audio_classification import AudioClassificationPipeline
|
||||
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
|
||||
from .base import (
|
||||
|
||||
@@ -61,25 +61,16 @@ from .hub import (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_file,
|
||||
cached_path,
|
||||
default_cache_path,
|
||||
define_sagemaker_information,
|
||||
filename_to_url,
|
||||
get_cached_models,
|
||||
get_file_from_repo,
|
||||
get_from_cache,
|
||||
get_full_repo_name,
|
||||
get_list_of_files,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
http_get,
|
||||
http_user_agent,
|
||||
is_local_clone,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
move_cache,
|
||||
send_example_telemetry,
|
||||
url_to_filename,
|
||||
)
|
||||
from .import_utils import (
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
||||
|
||||
@@ -14,44 +14,32 @@
|
||||
"""
|
||||
Hub utilities: utilities related to download and cache models
|
||||
"""
|
||||
import copy
|
||||
import fnmatch
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
import traceback
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from uuid import uuid4
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import huggingface_hub
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
from huggingface_hub import (
|
||||
CommitOperationAdd,
|
||||
HfFolder,
|
||||
create_commit,
|
||||
create_repo,
|
||||
hf_hub_download,
|
||||
list_repo_files,
|
||||
hf_hub_url,
|
||||
whoami,
|
||||
)
|
||||
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests.exceptions import HTTPError
|
||||
from requests.models import Response
|
||||
from transformers.utils.logging import tqdm
|
||||
|
||||
from . import __version__, logging
|
||||
@@ -128,93 +116,6 @@ HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{
|
||||
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
|
||||
|
||||
|
||||
def is_remote_url(url_or_filename):
    """
    Tell whether `url_or_filename` parses as a URL with an `http` or `https` scheme.

    Anything else (plain paths, other schemes, empty strings) yields `False`.
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
||||
|
||||
|
||||
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    """
    Build the download url for `filename` inside the repo `model_id`.

    Without a mirror, the url is a huggingface.co-hosted one, which redirects to Cloudfront (a Content Delivery
    Network, or CDN) for large files. Files are stored there in a content-addressable way (the file name is its
    hash), so the CDN's aggressive caching can never serve stale content.

    Client-side caching in this library is based on the objects' ETag: the sha1 for files stored in git, the sha256
    for files stored in git-lfs. Files cached locally from transformers before v3.5.0 are not shared with those new
    files, because the cached file's name contains a hash of the url (which changed).

    Args:
        model_id: identifier of the repo.
        filename: name of the file inside the repo.
        subfolder: optional folder inside the repo; when given it is prepended to `filename`.
        revision: optional revision id; `"main"` is used when it is `None`. Ignored when `mirror` is set.
        mirror: optional mirror base url. Legacy ids (no `/`) resolve to `{mirror}/{model_id}-{filename}`, other
            ids to `{mirror}/{model_id}/{filename}`.

    Return:
        The url (string) of the file.

    Raises:
        ValueError: if `mirror` is one of the discontinued `"tuna"` / `"bfsu"` mirrors.
    """
    path_in_repo = filename if subfolder is None else f"{subfolder}/{filename}"

    if not mirror:
        target_revision = "main" if revision is None else revision
        return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=target_revision, filename=path_in_repo)

    if mirror in ("tuna", "bfsu"):
        raise ValueError("The Tuna and BFSU mirrors are no longer available. Try removing the mirror argument.")

    # Legacy model ids carry no namespace and are joined to the file name with a dash.
    joiner = "/" if "/" in model_id else "-"
    return f"{mirror}/{model_id}{joiner}{path_in_repo}"
|
||||
|
||||
|
||||
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """
    Derive a repeatable, hashed file name from `url`.

    The name is the sha256 hex digest of the url. When a non-empty `etag` is given, the sha256 hex digest of the etag
    is appended after a period. Urls ending with `.h5` (Keras HDF5 weights) keep a trailing `.h5` so that TF 2.0 can
    identify the file as HDF5 (see
    https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
    """
    digests = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        digests.append(sha256(etag.encode("utf-8")).hexdigest())

    hashed_name = ".".join(digests)
    if url.endswith(".h5"):
        return hashed_name + ".h5"
    return hashed_name
|
||||
|
||||
|
||||
def filename_to_url(filename, cache_dir=None):
    """
    Look up the url and etag recorded for the cached file *filename*.

    The metadata is read from the JSON file stored next to the cached file (same path with `.json` appended).

    Args:
        filename: name of the cached file, relative to `cache_dir`.
        cache_dir: cache directory to look in (`str` or `Path`); defaults to `TRANSFORMERS_CACHE`.

    Return:
        A tuple `(url, etag)`; the etag may be `None`.

    Raises:
        EnvironmentError: if *filename* or its stored metadata do not exist.
    """
    cache_root = TRANSFORMERS_CACHE if cache_dir is None else cache_dir
    if isinstance(cache_root, Path):
        cache_root = str(cache_root)

    cached_entry = os.path.join(cache_root, filename)
    metadata_entry = cached_entry + ".json"

    # The cached file is checked first, then its metadata file.
    for required in (cached_entry, metadata_entry):
        if not os.path.exists(required):
            raise EnvironmentError(f"file {required} not found")

    with open(metadata_entry, encoding="utf-8") as handle:
        stored = json.load(handle)

    return stored["url"], stored["etag"]
|
||||
|
||||
|
||||
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
|
||||
"""
|
||||
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
|
||||
@@ -248,108 +149,6 @@ def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
|
||||
return cached_models
|
||||
|
||||
|
||||
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file
    and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and
    then return the path

    Args:
        url_or_filename: an http(s) url or a local path (`str` or `Path`).
        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        proxies: forwarded as-is to `get_from_cache` for remote urls.
        resume_download: if True, resume the download if incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
        use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
            will get token from ~/.huggingface.
        extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
            file in a folder along the archive.
        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
            re-extract the archive and override the folder where it was extracted.
        local_files_only: forwarded to `get_from_cache`; forced to True when offline mode is on.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk. When an archive
        was extracted, the path of the extraction folder is returned instead.

    Raises:
        EnvironmentError: if a local path does not exist, or if an archive's format could not be identified.
        ValueError: if `url_or_filename` is neither an http(s) url, an existing path, nor a scheme-less path.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError(f"file {url_or_filename} not found")
    else:
        # Something unknown
        raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            # NOTE(security): `extractall` trusts the member names stored in the archive; an archive from an
            # untrusted source can write outside `output_path_extracted`. Only extract archives you trust.
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
            elif tarfile.is_tarfile(output_path):
                # Context manager so the archive handle is released even if extraction fails.
                with tarfile.open(output_path) as tar_file:
                    tar_file.extractall(output_path_extracted)
            else:
                raise EnvironmentError(f"Archive format of {output_path} could not be identified")

        return output_path_extracted

    return output_path
|
||||
|
||||
|
||||
def define_sagemaker_information():
|
||||
try:
|
||||
instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
|
||||
@@ -399,234 +198,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
return ua
|
||||
|
||||
|
||||
def _raise_for_status(response: Response):
    """
    Internal version of `request.raise_for_status()` that will refine a potential HTTPError.

    The `X-Error-Code` response header, when it carries a known value, is mapped to `RepositoryNotFoundError`,
    `EntryNotFoundError` or `RevisionNotFoundError`. A 401 status is reported as `RepositoryNotFoundError`. Every
    other response is handed to `response.raise_for_status()`.
    """
    error_code = response.headers.get("X-Error-Code")

    if error_code == "RepoNotFound":
        raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {response.url}")
    if error_code == "EntryNotFound":
        raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {response.url}")
    if error_code == "RevisionNotFound":
        raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {response.url}")

    if response.status_code == 401:
        # The repo was not found and the user is not Authenticated
        raise RepositoryNotFoundError(
            f"401 Client Error: Repository not found for url: {response.url}. "
            "If the repo is private, make sure you are authenticated."
        )

    response.raise_for_status()
|
||||
|
||||
|
||||
def http_get(
    url: str,
    temp_file: BinaryIO,
    proxies=None,
    resume_size=0,
    headers: Optional[Dict[str, str]] = None,
    file_name: Optional[str] = None,
):
    """
    Download remote file. Do not gobble up errors.

    Args:
        url: url to download.
        temp_file: binary file object the downloaded chunks are written to.
        proxies: passed as-is to `requests.get`.
        resume_size: when greater than 0, a `Range: bytes={resume_size}-` header is sent and the progress bar starts
            at that offset.
        headers: optional request headers; the caller's dict is copied, never modified.
        file_name: optional name shown in the progress bar description.

    Raises:
        Whatever `_raise_for_status` raises for an error response.
    """
    headers = copy.deepcopy(headers)
    if resume_size > 0:
        # `headers` defaults to None: create the dict on demand so that resuming works without caller headers.
        if headers is None:
            headers = {}
        headers["Range"] = f"bytes={resume_size}-"
    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
    _raise_for_status(r)
    content_length = r.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
    # and can be set using `utils.logging.enable/disable_progress_bar()`
    progress = tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        total=total,
        initial=resume_size,
        desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
    )
    try:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                progress.update(len(chunk))
                temp_file.write(chunk)
    finally:
        # Close the bar even when the stream or the write fails; the error itself still propagates.
        progress.close()
|
||||
|
||||
|
||||
def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
    path to the cached file.

    Args:
        url (`str`):
            Address of the file to fetch.
        cache_dir (`str` or `Path`, *optional*):
            Directory holding the cached files; defaults to `TRANSFORMERS_CACHE`. It is created when missing.
        force_download (`bool`, *optional*, defaults to `False`):
            Download again even when a file for the current ETag is already cached.
        proxies (*optional*):
            Forwarded unchanged to the HTTP calls.
        etag_timeout (*optional*, defaults to 10):
            Timeout passed to the `HEAD` request used to read the ETag.
        resume_download (`bool`, *optional*, defaults to `False`):
            Append to a `<cache_path>.incomplete` file instead of starting over in a fresh temporary file.
        user_agent (`Dict` or `str`, *optional*):
            Passed to `http_user_agent` to build the `user-agent` header.
        use_auth_token (`bool` or `str`, *optional*):
            A string is sent as bearer token; any other truthy value uses the token returned by
            `HfFolder.get_token()`.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Skip every network call and only look at what is already in `cache_dir`.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk.

    Raises:
        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = f"Bearer {use_auth_token}"
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = f"Bearer {token}"

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            _raise_for_status(r)
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (
            requests.exceptions.SSLError,
            requests.exceptions.ProxyError,
            RepositoryNotFoundError,
            EntryNotFoundError,
            RevisionNotFoundError,
        ):
            # Actually raise for those subclasses of ConnectionError
            # Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
            raise
        except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # Otherwise, our Internet connection is down.
            # etag is None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path

        # Metadata (.json), lock (.lock) and partial-download (.incomplete) siblings share the same prefix but are
        # not usable payloads, so none of them may be returned as the cached file.
        candidates = [
            os.path.join(cache_dir, file)
            for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
            if not file.endswith((".json", ".lock", ".incomplete"))
        ]
        if candidates:
            # `os.listdir` returns entries in arbitrary order, so "last downloaded" has to be decided explicitly:
            # take the most recently modified candidate.
            return max(candidates, key=os.path.getmtime)

        # If files cannot be found and local_files_only=True,
        # the models might've been found if local_files_only=False
        # Notify the user about that
        if local_files_only:
            fname = url.split("/")[-1]
            raise EntryNotFoundError(
                f"Cannot find the requested file ({fname}) in the cached path and outgoing traffic has been"
                " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                " to False."
            )
        raise ValueError(
            "Connection error, and we cannot find the requested files in the cached path."
            " Please try again or make sure your Internet connection is on."
        )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        temp_path = None
        try:
            with temp_file_manager() as temp_file:
                temp_path = temp_file.name
                logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")

                # The url_to_download might be messy, so we extract the file name from the original url.
                file_name = url.split("/")[-1]
                http_get(
                    url_to_download,
                    temp_file,
                    proxies=proxies,
                    resume_size=resume_size,
                    headers=headers,
                    file_name=file_name,
                )
        except BaseException:
            # A failed non-resumable download leaves an orphan (`delete=False`) in the cache dir: remove it.
            # The `.incomplete` file of a resumable download is kept on purpose so a later call can continue it.
            if not resume_download and temp_path is not None and os.path.exists(temp_path):
                os.remove(temp_path)
            raise

        logger.info(f"storing {url} in cache at {cache_path}")
        os.replace(temp_path, cache_path)

        # NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
        # `os.umask` can only be read by setting it, hence the set-then-restore pair.
        umask = os.umask(0o666)
        os.umask(umask)
        os.chmod(cache_path, 0o666 & ~umask)

        logger.info(f"creating metadata file for {cache_path}")
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
|
||||
|
||||
|
||||
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
|
||||
"""
|
||||
Explores the cache to return the latest cached file for a given revision.
|
||||
@@ -919,7 +490,6 @@ def has_file(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
revision: Optional[str] = None,
|
||||
mirror: Optional[str] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
):
|
||||
@@ -936,7 +506,7 @@ def has_file(
|
||||
if os.path.isdir(path_or_repo):
|
||||
return os.path.isfile(os.path.join(path_or_repo, filename))
|
||||
|
||||
url = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=mirror)
|
||||
url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
|
||||
|
||||
headers = {"user-agent": http_user_agent()}
|
||||
if isinstance(use_auth_token, str):
|
||||
@@ -965,89 +535,6 @@ def has_file(
|
||||
return False
|
||||
|
||||
|
||||
def get_list_of_files(
    path_or_repo: Union[str, os.PathLike],
    revision: Optional[str] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    local_files_only: bool = False,
) -> List[str]:
    """
    Gets the list of files inside `path_or_repo`.

    Args:
        path_or_repo (`str` or `os.PathLike`):
            Can be either the id of a repo on huggingface.co or a path to a *directory*.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether or not to only rely on local files and not to attempt to download any files.

    <Tip warning={true}>

    This API is not optimized, so calling it a lot may result in connection errors.

    </Tip>

    Returns:
        `List[str]`: The list of files available in `path_or_repo`.

    Raises:
        `ValueError`: If `path_or_repo` is not a local directory and listing the repo fails with an `HTTPError`.
    """
    path_or_repo = str(path_or_repo)
    # If path_or_repo is a folder, we just return what is inside (subdirectories included).
    if os.path.isdir(path_or_repo):
        return [
            os.path.join(root, file_name)
            for root, _, file_names in os.walk(path_or_repo)
            for file_name in file_names
        ]

    # Can't grab the files if we are on offline mode.
    if is_offline_mode() or local_files_only:
        return []

    # Otherwise we grab the token and use the list_repo_files method.
    if isinstance(use_auth_token, str):
        token = use_auth_token
    elif use_auth_token is True:
        token = HfFolder.get_token()
    else:
        token = None

    try:
        return list_repo_files(path_or_repo, revision=revision, token=token)
    except HTTPError as e:
        raise ValueError(
            f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?"
        ) from e
|
||||
|
||||
|
||||
def is_local_clone(repo_path, repo_url):
    """
    Checks if the folder in `repo_path` is a local clone of `repo_url`.

    Args:
        repo_path: Folder to inspect.
        repo_url: Remote URL that must appear among the whitespace-separated tokens of `git remote -v`.

    Returns:
        `bool`: `False` when `repo_path` has no `.git` entry or `git branch` fails there; otherwise whether
        `repo_url` is listed among the remotes.

    Raises:
        `subprocess.CalledProcessError`: If `git remote -v` exits with a non-zero status.
    """
    # First double-check that `repo_path` is a git repo
    if not os.path.exists(os.path.join(repo_path, ".git")):
        return False
    # Only the exit status matters for this probe, so its output is discarded instead of being printed on the
    # caller's console.
    test_git = subprocess.run(
        ["git", "branch"],
        cwd=repo_path,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    if test_git.returncode != 0:
        return False

    # Then look at its remotes
    remotes = subprocess.run(
        ["git", "remote", "-v"],
        stderr=subprocess.PIPE,
        stdout=subprocess.PIPE,
        check=True,
        encoding="utf-8",
        cwd=repo_path,
    ).stdout

    return repo_url in remotes.split()
|
||||
|
||||
|
||||
class PushToHubMixin:
|
||||
"""
|
||||
A Mixin containing the functionality to push a model or tokenizer to the hub.
|
||||
@@ -1310,7 +797,6 @@ def get_checkpoint_shard_files(
|
||||
use_auth_token=None,
|
||||
user_agent=None,
|
||||
revision=None,
|
||||
mirror=None,
|
||||
subfolder="",
|
||||
):
|
||||
"""
|
||||
@@ -1343,18 +829,11 @@ def get_checkpoint_shard_files(
|
||||
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
|
||||
cached_filenames = []
|
||||
for shard_filename in shard_filenames:
|
||||
shard_url = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=shard_filename,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL
|
||||
cached_filename = cached_path(
|
||||
shard_url,
|
||||
cached_filename = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
shard_filename,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
@@ -1362,6 +841,8 @@ def get_checkpoint_shard_files(
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here.
|
||||
|
||||
@@ -26,20 +26,13 @@ import transformers
|
||||
from transformers import * # noqa F406
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||
from transformers.utils import (
|
||||
CONFIG_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ContextManagers,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
filename_to_url,
|
||||
find_labels,
|
||||
get_file_from_repo,
|
||||
get_from_cache,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
@@ -85,60 +78,6 @@ class TestImportMechanisms(unittest.TestCase):
|
||||
|
||||
|
||||
class GetFromCacheTests(unittest.TestCase):
    """Checks of `get_from_cache` error reporting and cache metadata, plus `has_file`."""

    def test_bogus_url(self):
        """An unreachable host surfaces as the generic connection `ValueError`."""
        # This lets us simulate no connection
        # as the error raised is the same
        # `ConnectionError`
        with self.assertRaisesRegex(ValueError, "Connection error"):
            get_from_cache("https://bogus")

    def test_file_not_found(self):
        """A missing file on an existing repo raises `EntryNotFoundError`."""
        # Valid revision (None) but missing file.
        missing_url = hf_bucket_url(MODEL_ID, filename="missing.bin")
        with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
            get_from_cache(missing_url)

    def test_model_not_found_not_authenticated(self):
        """An unknown repo queried without a token raises `RepositoryNotFoundError` with a 401 message."""
        # Invalid model id.
        bad_repo_url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
        with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
            get_from_cache(bad_repo_url)

    @unittest.skip("No authentication when testing against prod")
    def test_model_not_found_authenticated(self):
        """An unknown repo queried with a token raises `RepositoryNotFoundError` with a 404 message."""
        # Invalid model id.
        bad_repo_url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
        with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
            get_from_cache(bad_repo_url, use_auth_token="hf_sometoken")
            # ^ TODO - if we decide to unskip this: use a real / functional token

    def test_revision_not_found(self):
        """An existing file requested at an invalid revision raises `RevisionNotFoundError`."""
        # Valid file but missing revision
        bad_revision_url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
        with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
            get_from_cache(bad_revision_url)

    def test_standard_object(self):
        """The metadata recorded for the config at the default revision carries the pinned SHA-1 as etag."""
        config_url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
        local_path = get_from_cache(config_url, force_download=True)
        recorded = filename_to_url(local_path)
        self.assertEqual(recorded, (config_url, f'"{PINNED_SHA1}"'))

    def test_standard_object_rev(self):
        """The same config at another commit must be recorded with a different etag."""
        # Same object, but different revision
        config_url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
        local_path = get_from_cache(config_url, force_download=True)
        recorded = filename_to_url(local_path)
        # Caution: check that the etag is *not* equal to the one from `test_standard_object`
        self.assertNotEqual(recorded[1], f'"{PINNED_SHA1}"')

    def test_lfs_object(self):
        """The metadata recorded for the weights file carries the pinned SHA-256 as etag."""
        weights_url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
        local_path = get_from_cache(weights_url, force_download=True)
        recorded = filename_to_url(local_path)
        self.assertEqual(recorded, (weights_url, f'"{PINNED_SHA256}"'))

    def test_has_file(self):
        """`has_file` reports the PyTorch weights as present and the TF2 weights as absent on a PT-only repo."""
        pt_only_repo = "hf-internal-testing/tiny-bert-pt-only"
        self.assertTrue(has_file(pt_only_repo, WEIGHTS_NAME))
        self.assertFalse(has_file(pt_only_repo, TF2_WEIGHTS_NAME))
|
||||
|
||||
@@ -614,7 +614,6 @@ UNDOCUMENTED_OBJECTS = [
|
||||
"absl", # External module
|
||||
"add_end_docstrings", # Internal, should never have been in the main init.
|
||||
"add_start_docstrings", # Internal, should never have been in the main init.
|
||||
"cached_path", # Internal used for downloading models.
|
||||
"convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights
|
||||
"logger", # Internal logger
|
||||
"logging", # External module
|
||||
|
||||
Reference in New Issue
Block a user