Re-add support for single url files in objects download (#19014)
This commit is contained in:
@@ -32,7 +32,9 @@ from .utils import (
|
|||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
cached_file,
|
cached_file,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
download_url,
|
||||||
extract_commit_hash,
|
extract_commit_hash,
|
||||||
|
is_remote_url,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
@@ -592,9 +594,12 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
|
|
||||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
||||||
# Soecial case when pretrained_model_name_or_path is a local file
|
# Special case when pretrained_model_name_or_path is a local file
|
||||||
resolved_config_file = pretrained_model_name_or_path
|
resolved_config_file = pretrained_model_name_or_path
|
||||||
is_local = True
|
is_local = True
|
||||||
|
elif is_remote_url(pretrained_model_name_or_path):
|
||||||
|
configuration_file = pretrained_model_name_or_path
|
||||||
|
resolved_config_file = download_url(pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
||||||
|
|
||||||
|
|||||||
@@ -31,8 +31,10 @@ from .utils import (
|
|||||||
TensorType,
|
TensorType,
|
||||||
cached_file,
|
cached_file,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
download_url,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
|
is_remote_url,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -386,6 +388,9 @@ class FeatureExtractionMixin(PushToHubMixin):
|
|||||||
if os.path.isfile(pretrained_model_name_or_path):
|
if os.path.isfile(pretrained_model_name_or_path):
|
||||||
resolved_feature_extractor_file = pretrained_model_name_or_path
|
resolved_feature_extractor_file = pretrained_model_name_or_path
|
||||||
is_local = True
|
is_local = True
|
||||||
|
elif is_remote_url(pretrained_model_name_or_path):
|
||||||
|
feature_extractor_file = pretrained_model_name_or_path
|
||||||
|
resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
feature_extractor_file = FEATURE_EXTRACTOR_NAME
|
feature_extractor_file = FEATURE_EXTRACTOR_NAME
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -47,8 +47,10 @@ from .utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
cached_file,
|
cached_file,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
download_url,
|
||||||
has_file,
|
has_file,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
|
is_remote_url,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
@@ -677,6 +679,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
elif os.path.isfile(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
is_local = True
|
is_local = True
|
||||||
|
elif is_remote_url(pretrained_model_name_or_path):
|
||||||
|
archive_file = pretrained_model_name_or_path
|
||||||
|
resolved_archive_file = download_url(pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
|
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -54,9 +54,11 @@ from .utils import (
|
|||||||
ModelOutput,
|
ModelOutput,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
cached_file,
|
cached_file,
|
||||||
|
download_url,
|
||||||
find_labels,
|
find_labels,
|
||||||
has_file,
|
has_file,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
|
is_remote_url,
|
||||||
logging,
|
logging,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
working_or_temp_dir,
|
working_or_temp_dir,
|
||||||
@@ -2345,6 +2347,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
is_local = True
|
is_local = True
|
||||||
|
elif is_remote_url(pretrained_model_name_or_path):
|
||||||
|
archive_file = pretrained_model_name_or_path
|
||||||
|
resolved_archive_file = download_url(pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
# set correct filename
|
# set correct filename
|
||||||
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
|
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
|
||||||
|
|||||||
@@ -59,10 +59,12 @@ from .utils import (
|
|||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
cached_file,
|
cached_file,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
download_url,
|
||||||
has_file,
|
has_file,
|
||||||
is_accelerate_available,
|
is_accelerate_available,
|
||||||
is_bitsandbytes_available,
|
is_bitsandbytes_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
|
is_remote_url,
|
||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
@@ -1998,6 +2000,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
|
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
|
||||||
is_local = True
|
is_local = True
|
||||||
|
elif is_remote_url(pretrained_model_name_or_path):
|
||||||
|
archive_file = pretrained_model_name_or_path
|
||||||
|
resolved_archive_file = download_url(pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
# set correct filename
|
# set correct filename
|
||||||
if from_tf:
|
if from_tf:
|
||||||
|
|||||||
@@ -42,9 +42,11 @@ from .utils import (
|
|||||||
add_end_docstrings,
|
add_end_docstrings,
|
||||||
cached_file,
|
cached_file,
|
||||||
copy_func,
|
copy_func,
|
||||||
|
download_url,
|
||||||
extract_commit_hash,
|
extract_commit_hash,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
|
is_remote_url,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -1680,6 +1682,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
file_id = list(cls.vocab_files_names.keys())[0]
|
file_id = list(cls.vocab_files_names.keys())[0]
|
||||||
|
|
||||||
vocab_files[file_id] = pretrained_model_name_or_path
|
vocab_files[file_id] = pretrained_model_name_or_path
|
||||||
else:
|
else:
|
||||||
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
|
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
|
||||||
@@ -1723,6 +1726,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
for file_id, file_path in vocab_files.items():
|
for file_id, file_path in vocab_files.items():
|
||||||
if file_path is None:
|
if file_path is None:
|
||||||
resolved_vocab_files[file_id] = None
|
resolved_vocab_files[file_id] = None
|
||||||
|
elif is_remote_url(file_path):
|
||||||
|
resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies)
|
||||||
else:
|
else:
|
||||||
resolved_vocab_files[file_id] = cached_file(
|
resolved_vocab_files[file_id] = cached_file(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ from .hub import (
|
|||||||
cached_file,
|
cached_file,
|
||||||
default_cache_path,
|
default_cache_path,
|
||||||
define_sagemaker_information,
|
define_sagemaker_information,
|
||||||
|
download_url,
|
||||||
extract_commit_hash,
|
extract_commit_hash,
|
||||||
get_cached_models,
|
get_cached_models,
|
||||||
get_file_from_repo,
|
get_file_from_repo,
|
||||||
@@ -70,6 +71,7 @@ from .hub import (
|
|||||||
has_file,
|
has_file,
|
||||||
http_user_agent,
|
http_user_agent,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
|
is_remote_url,
|
||||||
move_cache,
|
move_cache,
|
||||||
send_example_telemetry,
|
send_example_telemetry,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,10 +19,12 @@ import os
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
@@ -37,7 +39,7 @@ from huggingface_hub import (
|
|||||||
whoami,
|
whoami,
|
||||||
)
|
)
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
||||||
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
|
||||||
from huggingface_hub.utils import (
|
from huggingface_hub.utils import (
|
||||||
EntryNotFoundError,
|
EntryNotFoundError,
|
||||||
LocalEntryNotFoundError,
|
LocalEntryNotFoundError,
|
||||||
@@ -124,6 +126,11 @@ HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/tele
|
|||||||
_CACHED_NO_EXIST = object()
|
_CACHED_NO_EXIST = object()
|
||||||
|
|
||||||
|
|
||||||
|
def is_remote_url(url_or_filename):
|
||||||
|
parsed = urlparse(url_or_filename)
|
||||||
|
return parsed.scheme in ("http", "https")
|
||||||
|
|
||||||
|
|
||||||
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
|
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
|
||||||
"""
|
"""
|
||||||
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
|
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
|
||||||
@@ -541,6 +548,32 @@ def get_file_from_repo(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_url(url, proxies=None):
|
||||||
|
"""
|
||||||
|
Downloads a given url in a temporary file. This function is not safe to use in multiple processes. Its only use is
|
||||||
|
for deprecated behavior allowing to download config/models with a single url instead of using the Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (`str`): The url of the file to download.
|
||||||
|
proxies (`Dict[str, str]`, *optional*):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: The location of the temporary file where the url was downloaded.
|
||||||
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
|
||||||
|
" v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
|
||||||
|
" that this is not compatible with the caching system (your file will be downloaded at each execution) or"
|
||||||
|
" multiple processes (each process will download the file in a different temporary file)."
|
||||||
|
)
|
||||||
|
tmp_file = tempfile.mktemp()
|
||||||
|
with open(tmp_file, "wb") as f:
|
||||||
|
http_get(url, f, proxies=proxies)
|
||||||
|
return tmp_file
|
||||||
|
|
||||||
|
|
||||||
def has_file(
|
def has_file(
|
||||||
path_or_repo: Union[str, os.PathLike],
|
path_or_repo: Union[str, os.PathLike],
|
||||||
filename: str,
|
filename: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user