[3/N] Use pyupgrade --py39-plus to improve code (#36936)
Use pyupgrade --py39-plus to improve code Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -90,7 +90,7 @@ def get_relative_imports(module_file: Union[str, os.PathLike]) -> list[str]:
|
||||
module_file (`str` or `os.PathLike`): The module file to inspect.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of relative imports in the module.
|
||||
`list[str]`: The list of relative imports in the module.
|
||||
"""
|
||||
with open(module_file, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
@@ -112,7 +112,7 @@ def get_relative_import_files(module_file: Union[str, os.PathLike]) -> list[str]
|
||||
module_file (`str` or `os.PathLike`): The module file to inspect.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
|
||||
`list[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
|
||||
of module files a given module needs.
|
||||
"""
|
||||
no_change = False
|
||||
@@ -144,7 +144,7 @@ def get_imports(filename: Union[str, os.PathLike]) -> list[str]:
|
||||
filename (`str` or `os.PathLike`): The module file to inspect.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of all packages required to use the input module.
|
||||
`list[str]`: The list of all packages required to use the input module.
|
||||
"""
|
||||
with open(filename, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
@@ -175,7 +175,7 @@ def check_imports(filename: Union[str, os.PathLike]) -> list[str]:
|
||||
filename (`str` or `os.PathLike`): The module file to check.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of relative imports in the file.
|
||||
`list[str]`: The list of relative imports in the file.
|
||||
"""
|
||||
imports = get_imports(filename)
|
||||
missing_packages = []
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -16,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import FrozenSet
|
||||
|
||||
from huggingface_hub import get_full_repo_name # for backward compatibility
|
||||
from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
|
||||
@@ -300,8 +298,8 @@ def check_min_version(min_version):
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_available_devices() -> FrozenSet[str]:
|
||||
@lru_cache
|
||||
def get_available_devices() -> frozenset[str]:
|
||||
"""
|
||||
Returns a frozenset of devices available for the current PyTorch installation.
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -17,7 +16,8 @@
|
||||
|
||||
import enum
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -75,9 +75,9 @@ def verify_out_features_out_indices(
|
||||
|
||||
|
||||
def _align_output_features_output_indices(
|
||||
out_features: Optional[List[str]],
|
||||
out_indices: Optional[Union[List[int], Tuple[int]]],
|
||||
stage_names: List[str],
|
||||
out_features: Optional[list[str]],
|
||||
out_indices: Optional[Union[list[int], tuple[int]]],
|
||||
stage_names: list[str],
|
||||
):
|
||||
"""
|
||||
Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.
|
||||
@@ -106,10 +106,10 @@ def _align_output_features_output_indices(
|
||||
|
||||
|
||||
def get_aligned_output_features_output_indices(
|
||||
out_features: Optional[List[str]],
|
||||
out_indices: Optional[Union[List[int], Tuple[int]]],
|
||||
stage_names: List[str],
|
||||
) -> Tuple[List[str], List[int]]:
|
||||
out_features: Optional[list[str]],
|
||||
out_indices: Optional[Union[list[int], tuple[int]]],
|
||||
stage_names: list[str],
|
||||
) -> tuple[list[str], list[int]]:
|
||||
"""
|
||||
Get the `out_features` and `out_indices` so that they are aligned.
|
||||
|
||||
@@ -198,7 +198,7 @@ class BackboneMixin:
|
||||
return self._out_features
|
||||
|
||||
@out_features.setter
|
||||
def out_features(self, out_features: List[str]):
|
||||
def out_features(self, out_features: list[str]):
|
||||
"""
|
||||
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
|
||||
"""
|
||||
@@ -211,7 +211,7 @@ class BackboneMixin:
|
||||
return self._out_indices
|
||||
|
||||
@out_indices.setter
|
||||
def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
|
||||
def out_indices(self, out_indices: Union[tuple[int], list[int]]):
|
||||
"""
|
||||
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
|
||||
"""
|
||||
@@ -264,7 +264,7 @@ class BackboneConfigMixin:
|
||||
return self._out_features
|
||||
|
||||
@out_features.setter
|
||||
def out_features(self, out_features: List[str]):
|
||||
def out_features(self, out_features: list[str]):
|
||||
"""
|
||||
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
|
||||
"""
|
||||
@@ -277,7 +277,7 @@ class BackboneConfigMixin:
|
||||
return self._out_indices
|
||||
|
||||
@out_indices.setter
|
||||
def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
|
||||
def out_indices(self, out_indices: Union[tuple[int], list[int]]):
|
||||
"""
|
||||
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
|
||||
"""
|
||||
|
||||
@@ -19,7 +19,7 @@ import types
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from packaging import version
|
||||
|
||||
@@ -71,7 +71,7 @@ class DocstringParsingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
||||
def _get_json_schema_type(param_type: str) -> dict[str, str]:
|
||||
type_mapping = {
|
||||
int: {"type": "integer"},
|
||||
float: {"type": "number"},
|
||||
@@ -87,7 +87,7 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
||||
return type_mapping.get(param_type, {"type": "object"})
|
||||
|
||||
|
||||
def _parse_type_hint(hint: str) -> Dict:
|
||||
def _parse_type_hint(hint: str) -> dict:
|
||||
origin = get_origin(hint)
|
||||
args = get_args(hint)
|
||||
|
||||
@@ -152,7 +152,7 @@ def _parse_type_hint(hint: str) -> Dict:
|
||||
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
|
||||
|
||||
|
||||
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
||||
def _convert_type_hints_to_json_schema(func: Callable) -> dict:
|
||||
type_hints = get_type_hints(func)
|
||||
signature = inspect.signature(func)
|
||||
required = []
|
||||
@@ -173,7 +173,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
||||
return schema
|
||||
|
||||
|
||||
def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
|
||||
def parse_google_format_docstring(docstring: str) -> tuple[Optional[str], Optional[dict], Optional[str]]:
|
||||
"""
|
||||
Parses a Google-style docstring to extract the function description,
|
||||
argument descriptions, and return description.
|
||||
@@ -206,7 +206,7 @@ def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Option
|
||||
return description, args_dict, returns
|
||||
|
||||
|
||||
def get_json_schema(func: Callable) -> Dict:
|
||||
def get_json_schema(func: Callable) -> dict:
|
||||
"""
|
||||
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
|
||||
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
|
||||
@@ -398,7 +398,7 @@ def _compile_jinja_template(chat_template):
|
||||
return self._rendered_blocks or self._generation_indices
|
||||
|
||||
@contextmanager
|
||||
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
|
||||
def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
|
||||
try:
|
||||
if self.is_active():
|
||||
raise ValueError("AssistantTracker should not be reused before closed")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -24,7 +23,7 @@ import os
|
||||
import random
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@@ -78,9 +77,9 @@ _IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE
|
||||
|
||||
|
||||
def _generate_supported_model_class_names(
|
||||
model_name: Type[PretrainedConfig],
|
||||
supported_tasks: Optional[Union[str, List[str]]] = None,
|
||||
) -> List[str]:
|
||||
model_name: type[PretrainedConfig],
|
||||
supported_tasks: Optional[Union[str, list[str]]] = None,
|
||||
) -> list[str]:
|
||||
task_mapping = {
|
||||
"default": MODEL_MAPPING_NAMES,
|
||||
"pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
|
||||
@@ -590,7 +589,7 @@ def operator_getitem(a, b):
|
||||
return operator.getitem(a, b)
|
||||
|
||||
|
||||
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
||||
_MANUAL_META_OVERRIDES: dict[Callable, Callable] = {
|
||||
torch.nn.Embedding: torch_nn_embedding,
|
||||
torch.nn.functional.embedding: torch_nn_functional_embedding,
|
||||
torch.nn.LayerNorm: torch_nn_layernorm,
|
||||
@@ -716,7 +715,7 @@ class HFCacheProxy(HFProxy):
|
||||
Proxy that represents an instance of `transformers.cache_utils.Cache`.
|
||||
"""
|
||||
|
||||
def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]):
|
||||
def install_orig_cache_cls(self, orig_cache_cls: type[Cache]):
|
||||
self._orig_cache_cls = orig_cache_cls
|
||||
|
||||
@property
|
||||
@@ -770,8 +769,8 @@ class HFProxyableClassMeta(type):
|
||||
def __new__(
|
||||
cls,
|
||||
name: str,
|
||||
bases: Tuple[Type, ...],
|
||||
attrs: Dict[str, Any],
|
||||
bases: tuple[type, ...],
|
||||
attrs: dict[str, Any],
|
||||
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
|
||||
):
|
||||
cls = super().__new__(cls, name, bases, attrs)
|
||||
@@ -794,7 +793,7 @@ class HFProxyableClassMeta(type):
|
||||
return cls
|
||||
|
||||
|
||||
def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]:
|
||||
def gen_constructor_wrapper(target: Callable) -> tuple[Callable, Callable]:
|
||||
"""
|
||||
Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
|
||||
"""
|
||||
@@ -813,7 +812,7 @@ def _proxies_to_metas(v):
|
||||
return v
|
||||
|
||||
|
||||
def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
|
||||
def create_cache_proxy_factory_fn(orig_cache_cls: type[Cache]) -> Callable[[Node], HFCacheProxy]:
|
||||
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
|
||||
global _CURRENT_TRACER
|
||||
if not isinstance(_CURRENT_TRACER, HFTracer):
|
||||
@@ -849,7 +848,7 @@ ProxyableStaticCache = HFProxyableClassMeta(
|
||||
)
|
||||
|
||||
|
||||
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
|
||||
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[list[int]] = None):
|
||||
if forbidden_values is None:
|
||||
forbidden_values = []
|
||||
value = random.randint(low, high)
|
||||
@@ -899,8 +898,8 @@ class HFTracer(Tracer):
|
||||
)
|
||||
|
||||
def _generate_dummy_input(
|
||||
self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
self, model: "PreTrainedModel", input_name: str, shape: list[int], input_names: list[str]
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Generates dummy input for model inference recording."""
|
||||
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
||||
# from pickle, or from the "__class__" attribute in the general case.
|
||||
@@ -1181,7 +1180,7 @@ class HFTracer(Tracer):
|
||||
return attr_val
|
||||
|
||||
# Needed for PyTorch 1.13+
|
||||
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
|
||||
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
|
||||
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||
|
||||
def call_module(self, m, forward, args, kwargs):
|
||||
@@ -1233,8 +1232,8 @@ class HFTracer(Tracer):
|
||||
def trace(
|
||||
self,
|
||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||
concrete_args: Optional[Dict[str, Any]] = None,
|
||||
dummy_inputs: Optional[Dict[str, Any]] = None,
|
||||
concrete_args: Optional[dict[str, Any]] = None,
|
||||
dummy_inputs: Optional[dict[str, Any]] = None,
|
||||
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
|
||||
) -> Graph:
|
||||
"""
|
||||
@@ -1422,7 +1421,7 @@ class HFTracer(Tracer):
|
||||
return attribute
|
||||
|
||||
|
||||
def get_concrete_args(model: nn.Module, input_names: List[str]):
|
||||
def get_concrete_args(model: nn.Module, input_names: list[str]):
|
||||
sig = inspect.signature(model.forward)
|
||||
|
||||
if not (set(input_names) <= set(sig.parameters.keys())):
|
||||
@@ -1450,9 +1449,9 @@ def check_if_model_is_supported(model: "PreTrainedModel"):
|
||||
|
||||
def symbolic_trace(
|
||||
model: "PreTrainedModel",
|
||||
input_names: Optional[List[str]] = None,
|
||||
input_names: Optional[list[str]] = None,
|
||||
disable_check: bool = False,
|
||||
tracer_cls: Type[HFTracer] = HFTracer,
|
||||
tracer_cls: type[HFTracer] = HFTracer,
|
||||
) -> GraphModule:
|
||||
"""
|
||||
Performs symbolic tracing on the model.
|
||||
|
||||
@@ -21,12 +21,12 @@ import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import OrderedDict, UserDict
|
||||
from collections.abc import MutableMapping
|
||||
from collections.abc import Iterable, MutableMapping
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from dataclasses import fields, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict
|
||||
from typing import Any, ContextManager, Optional, TypedDict
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
@@ -465,7 +465,7 @@ class ModelOutput(OrderedDict):
|
||||
args = tuple(getattr(self, field.name) for field in fields(self))
|
||||
return callable, args, *remaining
|
||||
|
||||
def to_tuple(self) -> Tuple[Any]:
|
||||
def to_tuple(self) -> tuple[Any]:
|
||||
"""
|
||||
Convert self to a tuple containing all the attributes/keys that are not `None`.
|
||||
"""
|
||||
@@ -475,7 +475,7 @@ class ModelOutput(OrderedDict):
|
||||
if is_torch_available():
|
||||
import torch.utils._pytree as _torch_pytree
|
||||
|
||||
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
|
||||
def _model_output_flatten(output: ModelOutput) -> tuple[list[Any], "_torch_pytree.Context"]:
|
||||
return list(output.values()), list(output.keys())
|
||||
|
||||
def _model_output_unflatten(
|
||||
@@ -542,7 +542,7 @@ class ContextManagers:
|
||||
in the `fastcore` library.
|
||||
"""
|
||||
|
||||
def __init__(self, context_managers: List[ContextManager]):
|
||||
def __init__(self, context_managers: list[ContextManager]):
|
||||
self.context_managers = context_managers
|
||||
self.stack = ExitStack()
|
||||
|
||||
@@ -883,7 +883,7 @@ class LossKwargs(TypedDict, total=False):
|
||||
num_items_in_batch: Optional[int]
|
||||
|
||||
|
||||
def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool:
|
||||
def is_timm_config_dict(config_dict: dict[str, Any]) -> bool:
|
||||
"""Checks whether a config dict is a timm config dict."""
|
||||
return "pretrained_cfg" in config_dict
|
||||
|
||||
@@ -903,13 +903,13 @@ def is_timm_local_checkpoint(pretrained_model_path: str) -> bool:
|
||||
|
||||
# pretrained_model_path is a file
|
||||
if is_file and pretrained_model_path.endswith(".json"):
|
||||
with open(pretrained_model_path, "r") as f:
|
||||
with open(pretrained_model_path) as f:
|
||||
config_dict = json.load(f)
|
||||
return is_timm_config_dict(config_dict)
|
||||
|
||||
# pretrained_model_path is a directory with a config.json
|
||||
if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")):
|
||||
with open(os.path.join(pretrained_model_path, "config.json"), "r") as f:
|
||||
with open(os.path.join(pretrained_model_path, "config.json")) as f:
|
||||
config_dict = json.load(f)
|
||||
return is_timm_config_dict(config_dict)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import tempfile
|
||||
import warnings
|
||||
from concurrent import futures
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -168,7 +168,7 @@ def define_sagemaker_information():
|
||||
return sagemaker_object
|
||||
|
||||
|
||||
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
def http_user_agent(user_agent: Union[dict, str, None] = None) -> str:
|
||||
"""
|
||||
Formats a user-agent string with basic info about a request.
|
||||
"""
|
||||
@@ -270,17 +270,17 @@ def cached_file(
|
||||
|
||||
def cached_files(
|
||||
path_or_repo_id: Union[str, os.PathLike],
|
||||
filenames: List[str],
|
||||
filenames: list[str],
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
resume_download: Optional[bool] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
proxies: Optional[dict[str, str]] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
subfolder: str = "",
|
||||
repo_type: Optional[str] = None,
|
||||
user_agent: Optional[Union[str, Dict[str, str]]] = None,
|
||||
user_agent: Optional[Union[str, dict[str, str]]] = None,
|
||||
_raise_exceptions_for_gated_repo: bool = True,
|
||||
_raise_exceptions_for_missing_entries: bool = True,
|
||||
_raise_exceptions_for_connection_errors: bool = True,
|
||||
@@ -378,7 +378,7 @@ def cached_files(
|
||||
if not os.path.isfile(resolved_file):
|
||||
if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"):
|
||||
revision_ = "main" if revision is None else revision
|
||||
raise EnvironmentError(
|
||||
raise OSError(
|
||||
f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout "
|
||||
f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files."
|
||||
)
|
||||
@@ -410,7 +410,7 @@ def cached_files(
|
||||
elif not _raise_exceptions_for_missing_entries:
|
||||
file_counter += 1
|
||||
else:
|
||||
raise EnvironmentError(f"Could not locate {filename} inside {path_or_repo_id}.")
|
||||
raise OSError(f"Could not locate {filename} inside {path_or_repo_id}.")
|
||||
|
||||
# Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries
|
||||
if file_counter == len(full_filenames):
|
||||
@@ -453,14 +453,14 @@ def cached_files(
|
||||
except Exception as e:
|
||||
# We cannot recover from them
|
||||
if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError):
|
||||
raise EnvironmentError(
|
||||
raise OSError(
|
||||
f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
|
||||
"having permission to this repo either by logging in with `huggingface-cli login` or by passing "
|
||||
"`token=<your_token>`"
|
||||
) from e
|
||||
elif isinstance(e, RevisionNotFoundError):
|
||||
raise EnvironmentError(
|
||||
raise OSError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||
"for this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
|
||||
@@ -478,7 +478,7 @@ def cached_files(
|
||||
if isinstance(e, GatedRepoError):
|
||||
if not _raise_exceptions_for_gated_repo:
|
||||
return None
|
||||
raise EnvironmentError(
|
||||
raise OSError(
|
||||
"You are trying to access a gated repo.\nMake sure to have access to it at "
|
||||
f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
|
||||
) from e
|
||||
@@ -488,7 +488,7 @@ def cached_files(
|
||||
# Here we only raise if both flags for missing entry and connection errors are True (because it can be raised
|
||||
# even when `local_files_only` is True, in which case raising for connections errors only would not make sense)
|
||||
elif _raise_exceptions_for_missing_entries:
|
||||
raise EnvironmentError(
|
||||
raise OSError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load the files, and couldn't find them in the"
|
||||
f" cached files.\nCheckout your internet connection or see how to run the library in offline mode at"
|
||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
@@ -498,9 +498,7 @@ def cached_files(
|
||||
elif isinstance(e, HTTPError) and not isinstance(e, EntryNotFoundError):
|
||||
if not _raise_exceptions_for_connection_errors:
|
||||
return None
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}"
|
||||
)
|
||||
raise OSError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}")
|
||||
|
||||
resolved_files = [
|
||||
_get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
|
||||
@@ -632,7 +630,7 @@ def has_file(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
revision: Optional[str] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
proxies: Optional[dict[str, str]] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
*,
|
||||
local_files_only: bool = False,
|
||||
@@ -707,19 +705,17 @@ def has_file(
|
||||
return True
|
||||
except GatedRepoError as e:
|
||||
logger.error(e)
|
||||
raise EnvironmentError(
|
||||
raise OSError(
|
||||
f"{path_or_repo} is a gated repository. Make sure to request access at "
|
||||
f"https://huggingface.co/{path_or_repo} and pass a token having permission to this repo either by "
|
||||
"logging in with `huggingface-cli login` or by passing `token=<your_token>`."
|
||||
) from e
|
||||
except RepositoryNotFoundError as e:
|
||||
logger.error(e)
|
||||
raise EnvironmentError(
|
||||
f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'."
|
||||
) from e
|
||||
raise OSError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.") from e
|
||||
except RevisionNotFoundError as e:
|
||||
logger.error(e)
|
||||
raise EnvironmentError(
|
||||
raise OSError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
|
||||
) from e
|
||||
@@ -780,7 +776,7 @@ class PushToHubMixin:
|
||||
self,
|
||||
working_dir: Union[str, os.PathLike],
|
||||
repo_id: str,
|
||||
files_timestamps: Dict[str, float],
|
||||
files_timestamps: dict[str, float],
|
||||
commit_message: Optional[str] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
create_pr: bool = False,
|
||||
@@ -867,7 +863,7 @@ class PushToHubMixin:
|
||||
safe_serialization: bool = True,
|
||||
revision: Optional[str] = None,
|
||||
commit_description: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**deprecated_kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -1101,7 +1097,7 @@ def get_checkpoint_shard_files(
|
||||
if not os.path.isfile(index_filename):
|
||||
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
|
||||
|
||||
with open(index_filename, "r") as f:
|
||||
with open(index_filename) as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
shard_filenames = sorted(set(index["weight_map"].values()))
|
||||
@@ -1136,7 +1132,7 @@ def get_checkpoint_shard_files(
|
||||
|
||||
def create_and_tag_model_card(
|
||||
repo_id: str,
|
||||
tags: Optional[List[str]] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
token: Optional[str] = None,
|
||||
ignore_metadata_errors: bool = False,
|
||||
):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Optuna, Hugging Face
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Hugging Face
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@@ -31,7 +31,7 @@ def find_adapter_config_file(
|
||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||
force_download: bool = False,
|
||||
resume_download: Optional[bool] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
proxies: Optional[dict[str, str]] = None,
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
local_files_only: bool = False,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: sentencepiece_model.proto
|
||||
"""Generated protocol buffer code."""
|
||||
|
||||
Reference in New Issue
Block a user