Clean up hub (#18497)
* Clean up utils.hub * Remove imports * More fixes * Last fix
This commit is contained in:
@@ -441,7 +441,6 @@ _import_structure = {
|
|||||||
"TensorType",
|
"TensorType",
|
||||||
"add_end_docstrings",
|
"add_end_docstrings",
|
||||||
"add_start_docstrings",
|
"add_start_docstrings",
|
||||||
"cached_path",
|
|
||||||
"is_apex_available",
|
"is_apex_available",
|
||||||
"is_datasets_available",
|
"is_datasets_available",
|
||||||
"is_faiss_available",
|
"is_faiss_available",
|
||||||
@@ -3214,7 +3213,6 @@ if TYPE_CHECKING:
|
|||||||
TensorType,
|
TensorType,
|
||||||
add_end_docstrings,
|
add_end_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
cached_path,
|
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
is_datasets_available,
|
is_datasets_available,
|
||||||
is_faiss_available,
|
is_faiss_available,
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ from . import (
|
|||||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
WEIGHTS_NAME,
|
|
||||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
@@ -91,11 +90,10 @@ from . import (
|
|||||||
XLMConfig,
|
XLMConfig,
|
||||||
XLMRobertaConfig,
|
XLMRobertaConfig,
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
cached_path,
|
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
load_pytorch_checkpoint_in_tf2_model,
|
load_pytorch_checkpoint_in_tf2_model,
|
||||||
)
|
)
|
||||||
from .utils import hf_bucket_url, logging
|
from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -311,7 +309,7 @@ def convert_pt_checkpoint_to_tf(
|
|||||||
|
|
||||||
# Initialise TF model
|
# Initialise TF model
|
||||||
if config_file in aws_config_map:
|
if config_file in aws_config_map:
|
||||||
config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
|
config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
|
||||||
config = config_class.from_json_file(config_file)
|
config = config_class.from_json_file(config_file)
|
||||||
config.output_hidden_states = True
|
config.output_hidden_states = True
|
||||||
config.output_attentions = True
|
config.output_attentions = True
|
||||||
@@ -320,8 +318,9 @@ def convert_pt_checkpoint_to_tf(
|
|||||||
|
|
||||||
# Load weights from tf checkpoint
|
# Load weights from tf checkpoint
|
||||||
if pytorch_checkpoint_path in aws_config_map.keys():
|
if pytorch_checkpoint_path in aws_config_map.keys():
|
||||||
pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME)
|
pytorch_checkpoint_path = cached_file(
|
||||||
pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models)
|
pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
|
||||||
|
)
|
||||||
# Load PyTorch checkpoint in tf2 model:
|
# Load PyTorch checkpoint in tf2 model:
|
||||||
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
|
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
|
||||||
|
|
||||||
@@ -395,14 +394,14 @@ def convert_all_pt_checkpoints_to_tf(
|
|||||||
print("-" * 100)
|
print("-" * 100)
|
||||||
|
|
||||||
if config_shortcut_name in aws_config_map:
|
if config_shortcut_name in aws_config_map:
|
||||||
config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
|
config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
|
||||||
else:
|
else:
|
||||||
config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)
|
config_file = config_shortcut_name
|
||||||
|
|
||||||
if model_shortcut_name in aws_model_maps:
|
if model_shortcut_name in aws_model_maps:
|
||||||
model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models)
|
model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
|
||||||
else:
|
else:
|
||||||
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
|
model_file = model_shortcut_name
|
||||||
|
|
||||||
if os.path.isfile(model_shortcut_name):
|
if os.path.isfile(model_shortcut_name):
|
||||||
model_shortcut_name = "converted_model"
|
model_shortcut_name = "converted_model"
|
||||||
|
|||||||
@@ -24,14 +24,7 @@ from typing import Dict, Optional, Union
|
|||||||
|
|
||||||
from huggingface_hub import HfFolder, model_info
|
from huggingface_hub import HfFolder, model_info
|
||||||
|
|
||||||
from .utils import (
|
from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging
|
||||||
HF_MODULES_CACHE,
|
|
||||||
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
|
||||||
cached_path,
|
|
||||||
hf_bucket_url,
|
|
||||||
is_offline_mode,
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@@ -219,18 +212,15 @@ def get_cached_module_file(
|
|||||||
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
|
|
||||||
submodule = "local"
|
submodule = "local"
|
||||||
else:
|
else:
|
||||||
module_file_or_url = hf_bucket_url(
|
|
||||||
pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
|
|
||||||
)
|
|
||||||
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
resolved_module_file = cached_path(
|
resolved_module_file = cached_file(
|
||||||
module_file_or_url,
|
pretrained_model_name_or_path,
|
||||||
|
module_file,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
|
|||||||
@@ -69,20 +69,14 @@ from .utils import (
|
|||||||
add_end_docstrings,
|
add_end_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
cached_path,
|
|
||||||
cached_property,
|
cached_property,
|
||||||
copy_func,
|
copy_func,
|
||||||
default_cache_path,
|
default_cache_path,
|
||||||
define_sagemaker_information,
|
define_sagemaker_information,
|
||||||
filename_to_url,
|
|
||||||
get_cached_models,
|
get_cached_models,
|
||||||
get_file_from_repo,
|
get_file_from_repo,
|
||||||
get_from_cache,
|
|
||||||
get_full_repo_name,
|
get_full_repo_name,
|
||||||
get_list_of_files,
|
|
||||||
has_file,
|
has_file,
|
||||||
hf_bucket_url,
|
|
||||||
http_get,
|
|
||||||
http_user_agent,
|
http_user_agent,
|
||||||
is_apex_available,
|
is_apex_available,
|
||||||
is_coloredlogs_available,
|
is_coloredlogs_available,
|
||||||
@@ -94,7 +88,6 @@ from .utils import (
|
|||||||
is_in_notebook,
|
is_in_notebook,
|
||||||
is_ipex_available,
|
is_ipex_available,
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
is_local_clone,
|
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
is_pandas_available,
|
is_pandas_available,
|
||||||
@@ -105,7 +98,6 @@ from .utils import (
|
|||||||
is_pyctcdecode_available,
|
is_pyctcdecode_available,
|
||||||
is_pytesseract_available,
|
is_pytesseract_available,
|
||||||
is_pytorch_quantization_available,
|
is_pytorch_quantization_available,
|
||||||
is_remote_url,
|
|
||||||
is_rjieba_available,
|
is_rjieba_available,
|
||||||
is_sagemaker_dp_enabled,
|
is_sagemaker_dp_enabled,
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
@@ -141,5 +133,4 @@ from .utils import (
|
|||||||
torch_only_method,
|
torch_only_method,
|
||||||
torch_required,
|
torch_required,
|
||||||
torch_version,
|
torch_version,
|
||||||
url_to_filename,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -43,15 +43,10 @@ from .models.auto.modeling_auto import (
|
|||||||
)
|
)
|
||||||
from .training_args import ParallelMode
|
from .training_args import ParallelMode
|
||||||
from .utils import (
|
from .utils import (
|
||||||
CONFIG_NAME,
|
|
||||||
MODEL_CARD_NAME,
|
MODEL_CARD_NAME,
|
||||||
TF2_WEIGHTS_NAME,
|
cached_file,
|
||||||
WEIGHTS_NAME,
|
|
||||||
cached_path,
|
|
||||||
hf_bucket_url,
|
|
||||||
is_datasets_available,
|
is_datasets_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -153,11 +148,6 @@ class ModelCard:
|
|||||||
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
|
||||||
'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
|
'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
|
||||||
|
|
||||||
find_from_standard_name: (*optional*) boolean, default True:
|
|
||||||
If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them
|
|
||||||
with our standard modelcard filename. Can be used to directly feed a model/config url and access the
|
|
||||||
colocated modelcard.
|
|
||||||
|
|
||||||
return_unused_kwargs: (*optional*) bool:
|
return_unused_kwargs: (*optional*) bool:
|
||||||
|
|
||||||
- If False, then this function returns just the final model card object.
|
- If False, then this function returns just the final model card object.
|
||||||
@@ -168,21 +158,15 @@ class ModelCard:
|
|||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
modelcard = ModelCard.from_pretrained(
|
# Download model card from huggingface.co and cache.
|
||||||
"bert-base-uncased"
|
modelcard = ModelCard.from_pretrained("bert-base-uncased")
|
||||||
) # Download model card from huggingface.co and cache.
|
# Model card was saved using *save_pretrained('./test/saved_model/')*
|
||||||
modelcard = ModelCard.from_pretrained(
|
modelcard = ModelCard.from_pretrained("./test/saved_model/")
|
||||||
"./test/saved_model/"
|
|
||||||
) # E.g. model card was saved using *save_pretrained('./test/saved_model/')*
|
|
||||||
modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
|
modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
|
||||||
modelcard = ModelCard.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
|
modelcard = ModelCard.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
|
||||||
```"""
|
```"""
|
||||||
# This imports every model so let's do it dynamically here.
|
|
||||||
from transformers.models.auto.configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
|
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
find_from_standard_name = kwargs.pop("find_from_standard_name", True)
|
|
||||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
|
|
||||||
@@ -190,37 +174,30 @@ class ModelCard:
|
|||||||
if from_pipeline is not None:
|
if from_pipeline is not None:
|
||||||
user_agent["using_pipeline"] = from_pipeline
|
user_agent["using_pipeline"] = from_pipeline
|
||||||
|
|
||||||
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
# For simplicity we use the same pretrained url than the configuration files
|
if os.path.isfile(pretrained_model_name_or_path):
|
||||||
# but with a different suffix (modelcard.json). This suffix is replaced below.
|
resolved_model_card_file = pretrained_model_name_or_path
|
||||||
model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
is_local = True
|
||||||
elif os.path.isdir(pretrained_model_name_or_path):
|
|
||||||
model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
|
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
|
||||||
model_card_file = pretrained_model_name_or_path
|
|
||||||
else:
|
else:
|
||||||
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)
|
try:
|
||||||
|
# Load from URL or cache if already cached
|
||||||
|
resolved_model_card_file = cached_file(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
filename=MODEL_CARD_NAME,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
proxies=proxies,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
if is_local:
|
||||||
|
logger.info(f"loading model card file {resolved_model_card_file}")
|
||||||
|
else:
|
||||||
|
logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
|
||||||
|
# Load model card
|
||||||
|
modelcard = cls.from_json_file(resolved_model_card_file)
|
||||||
|
|
||||||
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
except (EnvironmentError, json.JSONDecodeError):
|
||||||
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
|
# We fall back on creating an empty model card
|
||||||
model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME)
|
modelcard = cls()
|
||||||
model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load from URL or cache if already cached
|
|
||||||
resolved_model_card_file = cached_path(
|
|
||||||
model_card_file, cache_dir=cache_dir, proxies=proxies, user_agent=user_agent
|
|
||||||
)
|
|
||||||
if resolved_model_card_file == model_card_file:
|
|
||||||
logger.info(f"loading model card file {model_card_file}")
|
|
||||||
else:
|
|
||||||
logger.info(f"loading model card file {model_card_file} from cache at {resolved_model_card_file}")
|
|
||||||
# Load model card
|
|
||||||
modelcard = cls.from_json_file(resolved_model_card_file)
|
|
||||||
|
|
||||||
except (EnvironmentError, json.JSONDecodeError):
|
|
||||||
# We fall back on creating an empty model card
|
|
||||||
modelcard = cls()
|
|
||||||
|
|
||||||
# Update model card with kwargs if needed
|
# Update model card with kwargs if needed
|
||||||
to_remove = []
|
to_remove = []
|
||||||
|
|||||||
@@ -2156,7 +2156,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||||
mirror = kwargs.pop("mirror", None)
|
_ = kwargs.pop("mirror", None)
|
||||||
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
|
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
|
||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
@@ -2270,7 +2270,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
# message.
|
# message.
|
||||||
has_file_kwargs = {
|
has_file_kwargs = {
|
||||||
"revision": revision,
|
"revision": revision,
|
||||||
"mirror": mirror,
|
|
||||||
"proxies": proxies,
|
"proxies": proxies,
|
||||||
"use_auth_token": use_auth_token,
|
"use_auth_token": use_auth_token,
|
||||||
}
|
}
|
||||||
@@ -2321,7 +2320,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
mirror=mirror,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
config.name_or_path = pretrained_model_name_or_path
|
config.name_or_path = pretrained_model_name_or_path
|
||||||
|
|||||||
@@ -1784,7 +1784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||||
mirror = kwargs.pop("mirror", None)
|
_ = kwargs.pop("mirror", None)
|
||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
_fast_init = kwargs.pop("_fast_init", True)
|
_fast_init = kwargs.pop("_fast_init", True)
|
||||||
@@ -1955,7 +1955,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# message.
|
# message.
|
||||||
has_file_kwargs = {
|
has_file_kwargs = {
|
||||||
"revision": revision,
|
"revision": revision,
|
||||||
"mirror": mirror,
|
|
||||||
"proxies": proxies,
|
"proxies": proxies,
|
||||||
"use_auth_token": use_auth_token,
|
"use_auth_token": use_auth_token,
|
||||||
}
|
}
|
||||||
@@ -2012,7 +2011,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
mirror=mirror,
|
|
||||||
subfolder=subfolder,
|
subfolder=subfolder,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import numpy as np
|
|||||||
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_utils_base import BatchEncoding
|
from ...tokenization_utils_base import BatchEncoding
|
||||||
from ...utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url, logging, requires_backends
|
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
|
||||||
from .configuration_rag import RagConfig
|
from .configuration_rag import RagConfig
|
||||||
from .tokenization_rag import RagTokenizer
|
from .tokenization_rag import RagTokenizer
|
||||||
|
|
||||||
@@ -111,22 +111,21 @@ class LegacyIndex(Index):
|
|||||||
self._index_initialized = False
|
self._index_initialized = False
|
||||||
|
|
||||||
def _resolve_path(self, index_path, filename):
|
def _resolve_path(self, index_path, filename):
|
||||||
assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid `index_path`."
|
is_local = os.path.isdir(index_path)
|
||||||
archive_file = os.path.join(index_path, filename)
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
resolved_archive_file = cached_path(archive_file)
|
resolved_archive_file = cached_file(index_path, filename)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
msg = (
|
msg = (
|
||||||
f"Can't load '{archive_file}'. Make sure that:\n\n"
|
f"Can't load '{filename}'. Make sure that:\n\n"
|
||||||
f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
|
f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
|
||||||
f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
|
f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
|
||||||
)
|
)
|
||||||
raise EnvironmentError(msg)
|
raise EnvironmentError(msg)
|
||||||
if resolved_archive_file == archive_file:
|
if is_local:
|
||||||
logger.info(f"loading file {archive_file}")
|
logger.info(f"loading file {resolved_archive_file}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"loading file {archive_file} from cache at {resolved_archive_file}")
|
logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
|
||||||
return resolved_archive_file
|
return resolved_archive_file
|
||||||
|
|
||||||
def _load_passages(self):
|
def _load_passages(self):
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ import numpy as np
|
|||||||
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
cached_path,
|
cached_file,
|
||||||
is_sacremoses_available,
|
is_sacremoses_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
@@ -681,24 +681,21 @@ class TransfoXLCorpus(object):
|
|||||||
Instantiate a pre-processed corpus.
|
Instantiate a pre-processed corpus.
|
||||||
"""
|
"""
|
||||||
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
|
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||||
corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
|
|
||||||
else:
|
|
||||||
corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
|
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
|
resolved_corpus_file = cached_file(pretrained_model_name_or_path, CORPUS_NAME, cache_dir=cache_dir)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list"
|
f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list"
|
||||||
f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'"
|
f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'"
|
||||||
f" was a path or url but couldn't find files {corpus_file} at this path or url."
|
f" was a path or url but couldn't find files {CORPUS_NAME} at this path or url."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
if resolved_corpus_file == corpus_file:
|
if is_local:
|
||||||
logger.info(f"loading corpus file {corpus_file}")
|
logger.info(f"loading corpus file {resolved_corpus_file}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"loading corpus file {corpus_file} from cache at {resolved_corpus_file}")
|
logger.info(f"loading corpus file {CORPUS_NAME} from cache at {resolved_corpus_file}")
|
||||||
|
|
||||||
# Instantiate tokenizer.
|
# Instantiate tokenizer.
|
||||||
corpus = cls(*inputs, **kwargs)
|
corpus = cls(*inputs, **kwargs)
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
from numpy import isin
|
from numpy import isin
|
||||||
|
|
||||||
|
from huggingface_hub.file_download import http_get
|
||||||
|
|
||||||
from ..configuration_utils import PretrainedConfig
|
from ..configuration_utils import PretrainedConfig
|
||||||
from ..dynamic_module_utils import get_class_from_dynamic_module
|
from ..dynamic_module_utils import get_class_from_dynamic_module
|
||||||
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
from ..feature_extraction_utils import PreTrainedFeatureExtractor
|
||||||
@@ -33,7 +35,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, Aut
|
|||||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||||
from ..tokenization_utils import PreTrainedTokenizer
|
from ..tokenization_utils import PreTrainedTokenizer
|
||||||
from ..tokenization_utils_fast import PreTrainedTokenizerFast
|
from ..tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, http_get, is_tf_available, is_torch_available, logging
|
from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, is_tf_available, is_torch_available, logging
|
||||||
from .audio_classification import AudioClassificationPipeline
|
from .audio_classification import AudioClassificationPipeline
|
||||||
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
|
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
|
||||||
from .base import (
|
from .base import (
|
||||||
|
|||||||
@@ -61,25 +61,16 @@ from .hub import (
|
|||||||
RepositoryNotFoundError,
|
RepositoryNotFoundError,
|
||||||
RevisionNotFoundError,
|
RevisionNotFoundError,
|
||||||
cached_file,
|
cached_file,
|
||||||
cached_path,
|
|
||||||
default_cache_path,
|
default_cache_path,
|
||||||
define_sagemaker_information,
|
define_sagemaker_information,
|
||||||
filename_to_url,
|
|
||||||
get_cached_models,
|
get_cached_models,
|
||||||
get_file_from_repo,
|
get_file_from_repo,
|
||||||
get_from_cache,
|
|
||||||
get_full_repo_name,
|
get_full_repo_name,
|
||||||
get_list_of_files,
|
|
||||||
has_file,
|
has_file,
|
||||||
hf_bucket_url,
|
|
||||||
http_get,
|
|
||||||
http_user_agent,
|
http_user_agent,
|
||||||
is_local_clone,
|
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
|
||||||
move_cache,
|
move_cache,
|
||||||
send_example_telemetry,
|
send_example_telemetry,
|
||||||
url_to_filename,
|
|
||||||
)
|
)
|
||||||
from .import_utils import (
|
from .import_utils import (
|
||||||
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
||||||
|
|||||||
@@ -14,44 +14,32 @@
|
|||||||
"""
|
"""
|
||||||
Hub utilities: utilities related to download and cache models
|
Hub utilities: utilities related to download and cache models
|
||||||
"""
|
"""
|
||||||
import copy
|
|
||||||
import fnmatch
|
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
import tarfile
|
|
||||||
import tempfile
|
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
|
||||||
from hashlib import sha256
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import BinaryIO, Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
from urllib.parse import urlparse
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from zipfile import ZipFile, is_zipfile
|
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
import requests
|
import requests
|
||||||
from filelock import FileLock
|
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
CommitOperationAdd,
|
CommitOperationAdd,
|
||||||
HfFolder,
|
HfFolder,
|
||||||
create_commit,
|
create_commit,
|
||||||
create_repo,
|
create_repo,
|
||||||
hf_hub_download,
|
hf_hub_download,
|
||||||
list_repo_files,
|
hf_hub_url,
|
||||||
whoami,
|
whoami,
|
||||||
)
|
)
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
|
||||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from requests.models import Response
|
|
||||||
from transformers.utils.logging import tqdm
|
from transformers.utils.logging import tqdm
|
||||||
|
|
||||||
from . import __version__, logging
|
from . import __version__, logging
|
||||||
@@ -128,93 +116,6 @@ HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{
|
|||||||
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
|
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
|
||||||
|
|
||||||
|
|
||||||
def is_remote_url(url_or_filename):
|
|
||||||
parsed = urlparse(url_or_filename)
|
|
||||||
return parsed.scheme in ("http", "https")
|
|
||||||
|
|
||||||
|
|
||||||
def hf_bucket_url(
|
|
||||||
model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
|
|
||||||
to Cloudfront (a Content Delivery Network, or CDN) for large files.
|
|
||||||
|
|
||||||
Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
|
|
||||||
bandwidth costs).
|
|
||||||
|
|
||||||
Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
|
|
||||||
because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
|
|
||||||
in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
|
|
||||||
can't ever be stale.
|
|
||||||
|
|
||||||
In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is:
|
|
||||||
its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
|
|
||||||
are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
|
|
||||||
"""
|
|
||||||
if subfolder is not None:
|
|
||||||
filename = f"{subfolder}/{filename}"
|
|
||||||
|
|
||||||
if mirror:
|
|
||||||
if mirror in ["tuna", "bfsu"]:
|
|
||||||
raise ValueError("The Tuna and BFSU mirrors are no longer available. Try removing the mirror argument.")
|
|
||||||
legacy_format = "/" not in model_id
|
|
||||||
if legacy_format:
|
|
||||||
return f"{mirror}/{model_id}-{filename}"
|
|
||||||
else:
|
|
||||||
return f"{mirror}/{model_id}/{filename}"
|
|
||||||
|
|
||||||
if revision is None:
|
|
||||||
revision = "main"
|
|
||||||
return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
|
|
||||||
|
|
||||||
|
|
||||||
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
|
|
||||||
"""
|
|
||||||
Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
|
|
||||||
delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
|
|
||||||
identify it as a HDF5 file (see
|
|
||||||
https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
|
|
||||||
"""
|
|
||||||
url_bytes = url.encode("utf-8")
|
|
||||||
filename = sha256(url_bytes).hexdigest()
|
|
||||||
|
|
||||||
if etag:
|
|
||||||
etag_bytes = etag.encode("utf-8")
|
|
||||||
filename += "." + sha256(etag_bytes).hexdigest()
|
|
||||||
|
|
||||||
if url.endswith(".h5"):
|
|
||||||
filename += ".h5"
|
|
||||||
|
|
||||||
return filename
|
|
||||||
|
|
||||||
|
|
||||||
def filename_to_url(filename, cache_dir=None):
|
|
||||||
"""
|
|
||||||
Return the url and etag (which may be `None`) stored for *filename*. Raise `EnvironmentError` if *filename* or its
|
|
||||||
stored metadata do not exist.
|
|
||||||
"""
|
|
||||||
if cache_dir is None:
|
|
||||||
cache_dir = TRANSFORMERS_CACHE
|
|
||||||
if isinstance(cache_dir, Path):
|
|
||||||
cache_dir = str(cache_dir)
|
|
||||||
|
|
||||||
cache_path = os.path.join(cache_dir, filename)
|
|
||||||
if not os.path.exists(cache_path):
|
|
||||||
raise EnvironmentError(f"file {cache_path} not found")
|
|
||||||
|
|
||||||
meta_path = cache_path + ".json"
|
|
||||||
if not os.path.exists(meta_path):
|
|
||||||
raise EnvironmentError(f"file {meta_path} not found")
|
|
||||||
|
|
||||||
with open(meta_path, encoding="utf-8") as meta_file:
|
|
||||||
metadata = json.load(meta_file)
|
|
||||||
url = metadata["url"]
|
|
||||||
etag = metadata["etag"]
|
|
||||||
|
|
||||||
return url, etag
|
|
||||||
|
|
||||||
|
|
||||||
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
|
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
|
||||||
"""
|
"""
|
||||||
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
|
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
|
||||||
@@ -248,108 +149,6 @@ def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
|
|||||||
return cached_models
|
return cached_models
|
||||||
|
|
||||||
|
|
||||||
def cached_path(
|
|
||||||
url_or_filename,
|
|
||||||
cache_dir=None,
|
|
||||||
force_download=False,
|
|
||||||
proxies=None,
|
|
||||||
resume_download=False,
|
|
||||||
user_agent: Union[Dict, str, None] = None,
|
|
||||||
extract_compressed_file=False,
|
|
||||||
force_extract=False,
|
|
||||||
use_auth_token: Union[bool, str, None] = None,
|
|
||||||
local_files_only=False,
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file
|
|
||||||
and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and
|
|
||||||
then return the path
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
|
|
||||||
force_download: if True, re-download the file even if it's already cached in the cache dir.
|
|
||||||
resume_download: if True, resume the download if incompletely received file is found.
|
|
||||||
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
|
|
||||||
use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
|
|
||||||
will get token from ~/.huggingface.
|
|
||||||
extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
|
|
||||||
file in a folder along the archive.
|
|
||||||
force_extract: if True when extract_compressed_file is True and the archive was already extracted,
|
|
||||||
re-extract the archive and override the folder where it was extracted.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
Local path (string) of file or if networking is off, last version of file cached on disk.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
|
||||||
"""
|
|
||||||
if cache_dir is None:
|
|
||||||
cache_dir = TRANSFORMERS_CACHE
|
|
||||||
if isinstance(url_or_filename, Path):
|
|
||||||
url_or_filename = str(url_or_filename)
|
|
||||||
if isinstance(cache_dir, Path):
|
|
||||||
cache_dir = str(cache_dir)
|
|
||||||
|
|
||||||
if is_offline_mode() and not local_files_only:
|
|
||||||
logger.info("Offline mode: forcing local_files_only=True")
|
|
||||||
local_files_only = True
|
|
||||||
|
|
||||||
if is_remote_url(url_or_filename):
|
|
||||||
# URL, so get it from the cache (downloading if necessary)
|
|
||||||
output_path = get_from_cache(
|
|
||||||
url_or_filename,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_download=resume_download,
|
|
||||||
user_agent=user_agent,
|
|
||||||
use_auth_token=use_auth_token,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
elif os.path.exists(url_or_filename):
|
|
||||||
# File, and it exists.
|
|
||||||
output_path = url_or_filename
|
|
||||||
elif urlparse(url_or_filename).scheme == "":
|
|
||||||
# File, but it doesn't exist.
|
|
||||||
raise EnvironmentError(f"file {url_or_filename} not found")
|
|
||||||
else:
|
|
||||||
# Something unknown
|
|
||||||
raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")
|
|
||||||
|
|
||||||
if extract_compressed_file:
|
|
||||||
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
# Path where we extract compressed archives
|
|
||||||
# We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
|
|
||||||
output_dir, output_file = os.path.split(output_path)
|
|
||||||
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
|
|
||||||
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
|
|
||||||
|
|
||||||
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
|
|
||||||
return output_path_extracted
|
|
||||||
|
|
||||||
# Prevent parallel extractions
|
|
||||||
lock_path = output_path + ".lock"
|
|
||||||
with FileLock(lock_path):
|
|
||||||
shutil.rmtree(output_path_extracted, ignore_errors=True)
|
|
||||||
os.makedirs(output_path_extracted)
|
|
||||||
if is_zipfile(output_path):
|
|
||||||
with ZipFile(output_path, "r") as zip_file:
|
|
||||||
zip_file.extractall(output_path_extracted)
|
|
||||||
zip_file.close()
|
|
||||||
elif tarfile.is_tarfile(output_path):
|
|
||||||
tar_file = tarfile.open(output_path)
|
|
||||||
tar_file.extractall(output_path_extracted)
|
|
||||||
tar_file.close()
|
|
||||||
else:
|
|
||||||
raise EnvironmentError(f"Archive format of {output_path} could not be identified")
|
|
||||||
|
|
||||||
return output_path_extracted
|
|
||||||
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
def define_sagemaker_information():
|
def define_sagemaker_information():
|
||||||
try:
|
try:
|
||||||
instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
|
instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
|
||||||
@@ -399,234 +198,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
|||||||
return ua
|
return ua
|
||||||
|
|
||||||
|
|
||||||
def _raise_for_status(response: Response):
|
|
||||||
"""
|
|
||||||
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
|
|
||||||
"""
|
|
||||||
if "X-Error-Code" in response.headers:
|
|
||||||
error_code = response.headers["X-Error-Code"]
|
|
||||||
if error_code == "RepoNotFound":
|
|
||||||
raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {response.url}")
|
|
||||||
elif error_code == "EntryNotFound":
|
|
||||||
raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {response.url}")
|
|
||||||
elif error_code == "RevisionNotFound":
|
|
||||||
raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {response.url}")
|
|
||||||
|
|
||||||
if response.status_code == 401:
|
|
||||||
# The repo was not found and the user is not Authenticated
|
|
||||||
raise RepositoryNotFoundError(
|
|
||||||
f"401 Client Error: Repository not found for url: {response.url}. "
|
|
||||||
"If the repo is private, make sure you are authenticated."
|
|
||||||
)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
|
|
||||||
def http_get(
|
|
||||||
url: str,
|
|
||||||
temp_file: BinaryIO,
|
|
||||||
proxies=None,
|
|
||||||
resume_size=0,
|
|
||||||
headers: Optional[Dict[str, str]] = None,
|
|
||||||
file_name: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Download remote file. Do not gobble up errors.
|
|
||||||
"""
|
|
||||||
headers = copy.deepcopy(headers)
|
|
||||||
if resume_size > 0:
|
|
||||||
headers["Range"] = f"bytes={resume_size}-"
|
|
||||||
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
|
||||||
_raise_for_status(r)
|
|
||||||
content_length = r.headers.get("Content-Length")
|
|
||||||
total = resume_size + int(content_length) if content_length is not None else None
|
|
||||||
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
|
|
||||||
# and can be set using `utils.logging.enable/disable_progress_bar()`
|
|
||||||
progress = tqdm(
|
|
||||||
unit="B",
|
|
||||||
unit_scale=True,
|
|
||||||
unit_divisor=1024,
|
|
||||||
total=total,
|
|
||||||
initial=resume_size,
|
|
||||||
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
|
|
||||||
)
|
|
||||||
for chunk in r.iter_content(chunk_size=1024):
|
|
||||||
if chunk: # filter out keep-alive new chunks
|
|
||||||
progress.update(len(chunk))
|
|
||||||
temp_file.write(chunk)
|
|
||||||
progress.close()
|
|
||||||
|
|
||||||
|
|
||||||
def get_from_cache(
|
|
||||||
url: str,
|
|
||||||
cache_dir=None,
|
|
||||||
force_download=False,
|
|
||||||
proxies=None,
|
|
||||||
etag_timeout=10,
|
|
||||||
resume_download=False,
|
|
||||||
user_agent: Union[Dict, str, None] = None,
|
|
||||||
use_auth_token: Union[bool, str, None] = None,
|
|
||||||
local_files_only=False,
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
|
|
||||||
path to the cached file.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
Local path (string) of file or if networking is off, last version of file cached on disk.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
|
||||||
"""
|
|
||||||
if cache_dir is None:
|
|
||||||
cache_dir = TRANSFORMERS_CACHE
|
|
||||||
if isinstance(cache_dir, Path):
|
|
||||||
cache_dir = str(cache_dir)
|
|
||||||
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
|
||||||
|
|
||||||
headers = {"user-agent": http_user_agent(user_agent)}
|
|
||||||
if isinstance(use_auth_token, str):
|
|
||||||
headers["authorization"] = f"Bearer {use_auth_token}"
|
|
||||||
elif use_auth_token:
|
|
||||||
token = HfFolder.get_token()
|
|
||||||
if token is None:
|
|
||||||
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
|
|
||||||
headers["authorization"] = f"Bearer {token}"
|
|
||||||
|
|
||||||
url_to_download = url
|
|
||||||
etag = None
|
|
||||||
if not local_files_only:
|
|
||||||
try:
|
|
||||||
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
|
|
||||||
_raise_for_status(r)
|
|
||||||
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
|
|
||||||
# We favor a custom header indicating the etag of the linked resource, and
|
|
||||||
# we fallback to the regular etag header.
|
|
||||||
# If we don't have any of those, raise an error.
|
|
||||||
if etag is None:
|
|
||||||
raise OSError(
|
|
||||||
"Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
|
|
||||||
)
|
|
||||||
# In case of a redirect,
|
|
||||||
# save an extra redirect on the request.get call,
|
|
||||||
# and ensure we download the exact atomic version even if it changed
|
|
||||||
# between the HEAD and the GET (unlikely, but hey).
|
|
||||||
if 300 <= r.status_code <= 399:
|
|
||||||
url_to_download = r.headers["Location"]
|
|
||||||
except (
|
|
||||||
requests.exceptions.SSLError,
|
|
||||||
requests.exceptions.ProxyError,
|
|
||||||
RepositoryNotFoundError,
|
|
||||||
EntryNotFoundError,
|
|
||||||
RevisionNotFoundError,
|
|
||||||
):
|
|
||||||
# Actually raise for those subclasses of ConnectionError
|
|
||||||
# Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
|
|
||||||
raise
|
|
||||||
except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
|
|
||||||
# Otherwise, our Internet connection is down.
|
|
||||||
# etag is None
|
|
||||||
pass
|
|
||||||
|
|
||||||
filename = url_to_filename(url, etag)
|
|
||||||
|
|
||||||
# get cache path to put the file
|
|
||||||
cache_path = os.path.join(cache_dir, filename)
|
|
||||||
|
|
||||||
# etag is None == we don't have a connection or we passed local_files_only.
|
|
||||||
# try to get the last downloaded one
|
|
||||||
if etag is None:
|
|
||||||
if os.path.exists(cache_path):
|
|
||||||
return cache_path
|
|
||||||
else:
|
|
||||||
matching_files = [
|
|
||||||
file
|
|
||||||
for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
|
|
||||||
if not file.endswith(".json") and not file.endswith(".lock")
|
|
||||||
]
|
|
||||||
if len(matching_files) > 0:
|
|
||||||
return os.path.join(cache_dir, matching_files[-1])
|
|
||||||
else:
|
|
||||||
# If files cannot be found and local_files_only=True,
|
|
||||||
# the models might've been found if local_files_only=False
|
|
||||||
# Notify the user about that
|
|
||||||
if local_files_only:
|
|
||||||
fname = url.split("/")[-1]
|
|
||||||
raise EntryNotFoundError(
|
|
||||||
f"Cannot find the requested file ({fname}) in the cached path and outgoing traffic has been"
|
|
||||||
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
|
|
||||||
" to False."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Connection error, and we cannot find the requested files in the cached path."
|
|
||||||
" Please try again or make sure your Internet connection is on."
|
|
||||||
)
|
|
||||||
|
|
||||||
# From now on, etag is not None.
|
|
||||||
if os.path.exists(cache_path) and not force_download:
|
|
||||||
return cache_path
|
|
||||||
|
|
||||||
# Prevent parallel downloads of the same file with a lock.
|
|
||||||
lock_path = cache_path + ".lock"
|
|
||||||
with FileLock(lock_path):
|
|
||||||
|
|
||||||
# If the download just completed while the lock was activated.
|
|
||||||
if os.path.exists(cache_path) and not force_download:
|
|
||||||
# Even if returning early like here, the lock will be released.
|
|
||||||
return cache_path
|
|
||||||
|
|
||||||
if resume_download:
|
|
||||||
incomplete_path = cache_path + ".incomplete"
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _resumable_file_manager() -> "io.BufferedWriter":
|
|
||||||
with open(incomplete_path, "ab") as f:
|
|
||||||
yield f
|
|
||||||
|
|
||||||
temp_file_manager = _resumable_file_manager
|
|
||||||
if os.path.exists(incomplete_path):
|
|
||||||
resume_size = os.stat(incomplete_path).st_size
|
|
||||||
else:
|
|
||||||
resume_size = 0
|
|
||||||
else:
|
|
||||||
temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
|
|
||||||
resume_size = 0
|
|
||||||
|
|
||||||
# Download to temporary file, then copy to cache dir once finished.
|
|
||||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
|
||||||
with temp_file_manager() as temp_file:
|
|
||||||
logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
|
|
||||||
|
|
||||||
# The url_to_download might be messy, so we extract the file name from the original url.
|
|
||||||
file_name = url.split("/")[-1]
|
|
||||||
http_get(
|
|
||||||
url_to_download,
|
|
||||||
temp_file,
|
|
||||||
proxies=proxies,
|
|
||||||
resume_size=resume_size,
|
|
||||||
headers=headers,
|
|
||||||
file_name=file_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"storing {url} in cache at {cache_path}")
|
|
||||||
os.replace(temp_file.name, cache_path)
|
|
||||||
|
|
||||||
# NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
|
|
||||||
umask = os.umask(0o666)
|
|
||||||
os.umask(umask)
|
|
||||||
os.chmod(cache_path, 0o666 & ~umask)
|
|
||||||
|
|
||||||
logger.info(f"creating metadata file for {cache_path}")
|
|
||||||
meta = {"url": url, "etag": etag}
|
|
||||||
meta_path = cache_path + ".json"
|
|
||||||
with open(meta_path, "w") as meta_file:
|
|
||||||
json.dump(meta, meta_file)
|
|
||||||
|
|
||||||
return cache_path
|
|
||||||
|
|
||||||
|
|
||||||
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
|
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
|
||||||
"""
|
"""
|
||||||
Explores the cache to return the latest cached file for a given revision.
|
Explores the cache to return the latest cached file for a given revision.
|
||||||
@@ -919,7 +490,6 @@ def has_file(
|
|||||||
path_or_repo: Union[str, os.PathLike],
|
path_or_repo: Union[str, os.PathLike],
|
||||||
filename: str,
|
filename: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
mirror: Optional[str] = None,
|
|
||||||
proxies: Optional[Dict[str, str]] = None,
|
proxies: Optional[Dict[str, str]] = None,
|
||||||
use_auth_token: Optional[Union[bool, str]] = None,
|
use_auth_token: Optional[Union[bool, str]] = None,
|
||||||
):
|
):
|
||||||
@@ -936,7 +506,7 @@ def has_file(
|
|||||||
if os.path.isdir(path_or_repo):
|
if os.path.isdir(path_or_repo):
|
||||||
return os.path.isfile(os.path.join(path_or_repo, filename))
|
return os.path.isfile(os.path.join(path_or_repo, filename))
|
||||||
|
|
||||||
url = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=mirror)
|
url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
|
||||||
|
|
||||||
headers = {"user-agent": http_user_agent()}
|
headers = {"user-agent": http_user_agent()}
|
||||||
if isinstance(use_auth_token, str):
|
if isinstance(use_auth_token, str):
|
||||||
@@ -965,89 +535,6 @@ def has_file(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_list_of_files(
|
|
||||||
path_or_repo: Union[str, os.PathLike],
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
use_auth_token: Optional[Union[bool, str]] = None,
|
|
||||||
local_files_only: bool = False,
|
|
||||||
) -> List[str]:
|
|
||||||
"""
|
|
||||||
Gets the list of files inside `path_or_repo`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path_or_repo (`str` or `os.PathLike`):
|
|
||||||
Can be either the id of a repo on huggingface.co or a path to a *directory*.
|
|
||||||
revision (`str`, *optional*, defaults to `"main"`):
|
|
||||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
|
||||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
|
||||||
identifier allowed by git.
|
|
||||||
use_auth_token (`str` or *bool*, *optional*):
|
|
||||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
|
||||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
|
||||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to only rely on local files and not to attempt to download any files.
|
|
||||||
|
|
||||||
<Tip warning={true}>
|
|
||||||
|
|
||||||
This API is not optimized, so calling it a lot may result in connection errors.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`List[str]`: The list of files available in `path_or_repo`.
|
|
||||||
"""
|
|
||||||
path_or_repo = str(path_or_repo)
|
|
||||||
# If path_or_repo is a folder, we just return what is inside (subdirectories included).
|
|
||||||
if os.path.isdir(path_or_repo):
|
|
||||||
list_of_files = []
|
|
||||||
for path, dir_names, file_names in os.walk(path_or_repo):
|
|
||||||
list_of_files.extend([os.path.join(path, f) for f in file_names])
|
|
||||||
return list_of_files
|
|
||||||
|
|
||||||
# Can't grab the files if we are on offline mode.
|
|
||||||
if is_offline_mode() or local_files_only:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Otherwise we grab the token and use the list_repo_files method.
|
|
||||||
if isinstance(use_auth_token, str):
|
|
||||||
token = use_auth_token
|
|
||||||
elif use_auth_token is True:
|
|
||||||
token = HfFolder.get_token()
|
|
||||||
else:
|
|
||||||
token = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
return list_repo_files(path_or_repo, revision=revision, token=token)
|
|
||||||
except HTTPError as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
def is_local_clone(repo_path, repo_url):
|
|
||||||
"""
|
|
||||||
Checks if the folder in `repo_path` is a local clone of `repo_url`.
|
|
||||||
"""
|
|
||||||
# First double-check that `repo_path` is a git repo
|
|
||||||
if not os.path.exists(os.path.join(repo_path, ".git")):
|
|
||||||
return False
|
|
||||||
test_git = subprocess.run("git branch".split(), cwd=repo_path)
|
|
||||||
if test_git.returncode != 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Then look at its remotes
|
|
||||||
remotes = subprocess.run(
|
|
||||||
"git remote -v".split(),
|
|
||||||
stderr=subprocess.PIPE,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
check=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
cwd=repo_path,
|
|
||||||
).stdout
|
|
||||||
|
|
||||||
return repo_url in remotes.split()
|
|
||||||
|
|
||||||
|
|
||||||
class PushToHubMixin:
|
class PushToHubMixin:
|
||||||
"""
|
"""
|
||||||
A Mixin containing the functionality to push a model or tokenizer to the hub.
|
A Mixin containing the functionality to push a model or tokenizer to the hub.
|
||||||
@@ -1310,7 +797,6 @@ def get_checkpoint_shard_files(
|
|||||||
use_auth_token=None,
|
use_auth_token=None,
|
||||||
user_agent=None,
|
user_agent=None,
|
||||||
revision=None,
|
revision=None,
|
||||||
mirror=None,
|
|
||||||
subfolder="",
|
subfolder="",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1343,18 +829,11 @@ def get_checkpoint_shard_files(
|
|||||||
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
|
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
|
||||||
cached_filenames = []
|
cached_filenames = []
|
||||||
for shard_filename in shard_filenames:
|
for shard_filename in shard_filenames:
|
||||||
shard_url = hf_bucket_url(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
filename=shard_filename,
|
|
||||||
revision=revision,
|
|
||||||
mirror=mirror,
|
|
||||||
subfolder=subfolder if len(subfolder) > 0 else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load from URL
|
# Load from URL
|
||||||
cached_filename = cached_path(
|
cached_filename = cached_file(
|
||||||
shard_url,
|
pretrained_model_name_or_path,
|
||||||
|
shard_filename,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
@@ -1362,6 +841,8 @@ def get_checkpoint_shard_files(
|
|||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
)
|
)
|
||||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||||
# we don't have to catch them here.
|
# we don't have to catch them here.
|
||||||
|
|||||||
@@ -26,20 +26,13 @@ import transformers
|
|||||||
from transformers import * # noqa F406
|
from transformers import * # noqa F406
|
||||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
CONFIG_NAME,
|
|
||||||
FLAX_WEIGHTS_NAME,
|
FLAX_WEIGHTS_NAME,
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
ContextManagers,
|
ContextManagers,
|
||||||
EntryNotFoundError,
|
|
||||||
RepositoryNotFoundError,
|
|
||||||
RevisionNotFoundError,
|
|
||||||
filename_to_url,
|
|
||||||
find_labels,
|
find_labels,
|
||||||
get_file_from_repo,
|
get_file_from_repo,
|
||||||
get_from_cache,
|
|
||||||
has_file,
|
has_file,
|
||||||
hf_bucket_url,
|
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -85,60 +78,6 @@ class TestImportMechanisms(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class GetFromCacheTests(unittest.TestCase):
|
class GetFromCacheTests(unittest.TestCase):
|
||||||
def test_bogus_url(self):
|
|
||||||
# This lets us simulate no connection
|
|
||||||
# as the error raised is the same
|
|
||||||
# `ConnectionError`
|
|
||||||
url = "https://bogus"
|
|
||||||
with self.assertRaisesRegex(ValueError, "Connection error"):
|
|
||||||
_ = get_from_cache(url)
|
|
||||||
|
|
||||||
def test_file_not_found(self):
|
|
||||||
# Valid revision (None) but missing file.
|
|
||||||
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
|
|
||||||
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
|
|
||||||
_ = get_from_cache(url)
|
|
||||||
|
|
||||||
def test_model_not_found_not_authenticated(self):
|
|
||||||
# Invalid model id.
|
|
||||||
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
|
|
||||||
with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
|
|
||||||
_ = get_from_cache(url)
|
|
||||||
|
|
||||||
@unittest.skip("No authentication when testing against prod")
|
|
||||||
def test_model_not_found_authenticated(self):
|
|
||||||
# Invalid model id.
|
|
||||||
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
|
|
||||||
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
|
|
||||||
_ = get_from_cache(url, use_auth_token="hf_sometoken")
|
|
||||||
# ^ TODO - if we decide to unskip this: use a real / functional token
|
|
||||||
|
|
||||||
def test_revision_not_found(self):
|
|
||||||
# Valid file but missing revision
|
|
||||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
|
|
||||||
with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
|
|
||||||
_ = get_from_cache(url)
|
|
||||||
|
|
||||||
def test_standard_object(self):
|
|
||||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
|
|
||||||
filepath = get_from_cache(url, force_download=True)
|
|
||||||
metadata = filename_to_url(filepath)
|
|
||||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
|
|
||||||
|
|
||||||
def test_standard_object_rev(self):
|
|
||||||
# Same object, but different revision
|
|
||||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
|
|
||||||
filepath = get_from_cache(url, force_download=True)
|
|
||||||
metadata = filename_to_url(filepath)
|
|
||||||
self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
|
|
||||||
# Caution: check that the etag is *not* equal to the one from `test_standard_object`
|
|
||||||
|
|
||||||
def test_lfs_object(self):
|
|
||||||
url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
|
|
||||||
filepath = get_from_cache(url, force_download=True)
|
|
||||||
metadata = filename_to_url(filepath)
|
|
||||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
|
||||||
|
|
||||||
def test_has_file(self):
|
def test_has_file(self):
|
||||||
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
|
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
|
||||||
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
|
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
|
||||||
|
|||||||
@@ -614,7 +614,6 @@ UNDOCUMENTED_OBJECTS = [
|
|||||||
"absl", # External module
|
"absl", # External module
|
||||||
"add_end_docstrings", # Internal, should never have been in the main init.
|
"add_end_docstrings", # Internal, should never have been in the main init.
|
||||||
"add_start_docstrings", # Internal, should never have been in the main init.
|
"add_start_docstrings", # Internal, should never have been in the main init.
|
||||||
"cached_path", # Internal used for downloading models.
|
|
||||||
"convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights
|
"convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights
|
||||||
"logger", # Internal logger
|
"logger", # Internal logger
|
||||||
"logging", # External module
|
"logging", # External module
|
||||||
|
|||||||
Reference in New Issue
Block a user