[3/N] Use pyupgrade --py39-plus to improve code (#36936)

Use pyupgrade --py39-plus to improve code

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-02 21:16:06 +08:00
committed by GitHub
parent 764ab0d46a
commit 32c12aaec3
14 changed files with 75 additions and 88 deletions

View File

@@ -90,7 +90,7 @@ def get_relative_imports(module_file: Union[str, os.PathLike]) -> list[str]:
module_file (`str` or `os.PathLike`): The module file to inspect. module_file (`str` or `os.PathLike`): The module file to inspect.
Returns: Returns:
`List[str]`: The list of relative imports in the module. `list[str]`: The list of relative imports in the module.
""" """
with open(module_file, encoding="utf-8") as f: with open(module_file, encoding="utf-8") as f:
content = f.read() content = f.read()
@@ -112,7 +112,7 @@ def get_relative_import_files(module_file: Union[str, os.PathLike]) -> list[str]
module_file (`str` or `os.PathLike`): The module file to inspect. module_file (`str` or `os.PathLike`): The module file to inspect.
Returns: Returns:
`List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list `list[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
of module files a given module needs. of module files a given module needs.
""" """
no_change = False no_change = False
@@ -144,7 +144,7 @@ def get_imports(filename: Union[str, os.PathLike]) -> list[str]:
filename (`str` or `os.PathLike`): The module file to inspect. filename (`str` or `os.PathLike`): The module file to inspect.
Returns: Returns:
`List[str]`: The list of all packages required to use the input module. `list[str]`: The list of all packages required to use the input module.
""" """
with open(filename, encoding="utf-8") as f: with open(filename, encoding="utf-8") as f:
content = f.read() content = f.read()
@@ -175,7 +175,7 @@ def check_imports(filename: Union[str, os.PathLike]) -> list[str]:
filename (`str` or `os.PathLike`): The module file to check. filename (`str` or `os.PathLike`): The module file to check.
Returns: Returns:
`List[str]`: The list of relative imports in the file. `list[str]`: The list of relative imports in the file.
""" """
imports = get_imports(filename) imports = get_imports(filename)
missing_packages = [] missing_packages = []

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. # Copyright 2025 The HuggingFace Inc. team.
# All rights reserved. # All rights reserved.
# #

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python #!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved. # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
# #
@@ -16,7 +15,6 @@
# limitations under the License. # limitations under the License.
from functools import lru_cache from functools import lru_cache
from typing import FrozenSet
from huggingface_hub import get_full_repo_name # for backward compatibility from huggingface_hub import get_full_repo_name # for backward compatibility
from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
@@ -300,8 +298,8 @@ def check_min_version(min_version):
) )
@lru_cache() @lru_cache
def get_available_devices() -> FrozenSet[str]: def get_available_devices() -> frozenset[str]:
""" """
Returns a frozenset of devices available for the current PyTorch installation. Returns a frozenset of devices available for the current PyTorch installation.
""" """

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. # Copyright 2025 The HuggingFace Inc. team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. # Copyright 2023 The HuggingFace Inc. team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +16,8 @@
import enum import enum
import inspect import inspect
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -75,9 +75,9 @@ def verify_out_features_out_indices(
def _align_output_features_output_indices( def _align_output_features_output_indices(
out_features: Optional[List[str]], out_features: Optional[list[str]],
out_indices: Optional[Union[List[int], Tuple[int]]], out_indices: Optional[Union[list[int], tuple[int]]],
stage_names: List[str], stage_names: list[str],
): ):
""" """
Finds the corresponding `out_features` and `out_indices` for the given `stage_names`. Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.
@@ -106,10 +106,10 @@ def _align_output_features_output_indices(
def get_aligned_output_features_output_indices( def get_aligned_output_features_output_indices(
out_features: Optional[List[str]], out_features: Optional[list[str]],
out_indices: Optional[Union[List[int], Tuple[int]]], out_indices: Optional[Union[list[int], tuple[int]]],
stage_names: List[str], stage_names: list[str],
) -> Tuple[List[str], List[int]]: ) -> tuple[list[str], list[int]]:
""" """
Get the `out_features` and `out_indices` so that they are aligned. Get the `out_features` and `out_indices` so that they are aligned.
@@ -198,7 +198,7 @@ class BackboneMixin:
return self._out_features return self._out_features
@out_features.setter @out_features.setter
def out_features(self, out_features: List[str]): def out_features(self, out_features: list[str]):
""" """
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features. Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
""" """
@@ -211,7 +211,7 @@ class BackboneMixin:
return self._out_indices return self._out_indices
@out_indices.setter @out_indices.setter
def out_indices(self, out_indices: Union[Tuple[int], List[int]]): def out_indices(self, out_indices: Union[tuple[int], list[int]]):
""" """
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices. Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
""" """
@@ -264,7 +264,7 @@ class BackboneConfigMixin:
return self._out_features return self._out_features
@out_features.setter @out_features.setter
def out_features(self, out_features: List[str]): def out_features(self, out_features: list[str]):
""" """
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features. Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
""" """
@@ -277,7 +277,7 @@ class BackboneConfigMixin:
return self._out_indices return self._out_indices
@out_indices.setter @out_indices.setter
def out_indices(self, out_indices: Union[Tuple[int], List[int]]): def out_indices(self, out_indices: Union[tuple[int], list[int]]):
""" """
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices. Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
""" """

View File

@@ -19,7 +19,7 @@ import types
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from functools import lru_cache from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
from packaging import version from packaging import version
@@ -71,7 +71,7 @@ class DocstringParsingException(Exception):
pass pass
def _get_json_schema_type(param_type: str) -> Dict[str, str]: def _get_json_schema_type(param_type: str) -> dict[str, str]:
type_mapping = { type_mapping = {
int: {"type": "integer"}, int: {"type": "integer"},
float: {"type": "number"}, float: {"type": "number"},
@@ -87,7 +87,7 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
return type_mapping.get(param_type, {"type": "object"}) return type_mapping.get(param_type, {"type": "object"})
def _parse_type_hint(hint: str) -> Dict: def _parse_type_hint(hint: str) -> dict:
origin = get_origin(hint) origin = get_origin(hint)
args = get_args(hint) args = get_args(hint)
@@ -152,7 +152,7 @@ def _parse_type_hint(hint: str) -> Dict:
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint) raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
def _convert_type_hints_to_json_schema(func: Callable) -> Dict: def _convert_type_hints_to_json_schema(func: Callable) -> dict:
type_hints = get_type_hints(func) type_hints = get_type_hints(func)
signature = inspect.signature(func) signature = inspect.signature(func)
required = [] required = []
@@ -173,7 +173,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
return schema return schema
def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]: def parse_google_format_docstring(docstring: str) -> tuple[Optional[str], Optional[dict], Optional[str]]:
""" """
Parses a Google-style docstring to extract the function description, Parses a Google-style docstring to extract the function description,
argument descriptions, and return description. argument descriptions, and return description.
@@ -206,7 +206,7 @@ def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Option
return description, args_dict, returns return description, args_dict, returns
def get_json_schema(func: Callable) -> Dict: def get_json_schema(func: Callable) -> dict:
""" """
This function generates a JSON schema for a given function, based on its docstring and type hints. This is This function generates a JSON schema for a given function, based on its docstring and type hints. This is
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
@@ -398,7 +398,7 @@ def _compile_jinja_template(chat_template):
return self._rendered_blocks or self._generation_indices return self._rendered_blocks or self._generation_indices
@contextmanager @contextmanager
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]): def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
try: try:
if self.is_active(): if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed") raise ValueError("AssistantTracker should not be reused before closed")

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved. # Copyright 2021 The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +23,7 @@ import os
import random import random
import sys import sys
import warnings import warnings
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union from typing import Any, Callable, Literal, Optional, Union
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@@ -78,9 +77,9 @@ _IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE
def _generate_supported_model_class_names( def _generate_supported_model_class_names(
model_name: Type[PretrainedConfig], model_name: type[PretrainedConfig],
supported_tasks: Optional[Union[str, List[str]]] = None, supported_tasks: Optional[Union[str, list[str]]] = None,
) -> List[str]: ) -> list[str]:
task_mapping = { task_mapping = {
"default": MODEL_MAPPING_NAMES, "default": MODEL_MAPPING_NAMES,
"pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES, "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
@@ -590,7 +589,7 @@ def operator_getitem(a, b):
return operator.getitem(a, b) return operator.getitem(a, b)
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = { _MANUAL_META_OVERRIDES: dict[Callable, Callable] = {
torch.nn.Embedding: torch_nn_embedding, torch.nn.Embedding: torch_nn_embedding,
torch.nn.functional.embedding: torch_nn_functional_embedding, torch.nn.functional.embedding: torch_nn_functional_embedding,
torch.nn.LayerNorm: torch_nn_layernorm, torch.nn.LayerNorm: torch_nn_layernorm,
@@ -716,7 +715,7 @@ class HFCacheProxy(HFProxy):
Proxy that represents an instance of `transformers.cache_utils.Cache`. Proxy that represents an instance of `transformers.cache_utils.Cache`.
""" """
def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]): def install_orig_cache_cls(self, orig_cache_cls: type[Cache]):
self._orig_cache_cls = orig_cache_cls self._orig_cache_cls = orig_cache_cls
@property @property
@@ -770,8 +769,8 @@ class HFProxyableClassMeta(type):
def __new__( def __new__(
cls, cls,
name: str, name: str,
bases: Tuple[Type, ...], bases: tuple[type, ...],
attrs: Dict[str, Any], attrs: dict[str, Any],
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None, proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
): ):
cls = super().__new__(cls, name, bases, attrs) cls = super().__new__(cls, name, bases, attrs)
@@ -794,7 +793,7 @@ class HFProxyableClassMeta(type):
return cls return cls
def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]: def gen_constructor_wrapper(target: Callable) -> tuple[Callable, Callable]:
""" """
Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on. Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
""" """
@@ -813,7 +812,7 @@ def _proxies_to_metas(v):
return v return v
def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]: def create_cache_proxy_factory_fn(orig_cache_cls: type[Cache]) -> Callable[[Node], HFCacheProxy]:
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy: def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
global _CURRENT_TRACER global _CURRENT_TRACER
if not isinstance(_CURRENT_TRACER, HFTracer): if not isinstance(_CURRENT_TRACER, HFTracer):
@@ -849,7 +848,7 @@ ProxyableStaticCache = HFProxyableClassMeta(
) )
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[list[int]] = None):
if forbidden_values is None: if forbidden_values is None:
forbidden_values = [] forbidden_values = []
value = random.randint(low, high) value = random.randint(low, high)
@@ -899,8 +898,8 @@ class HFTracer(Tracer):
) )
def _generate_dummy_input( def _generate_dummy_input(
self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str] self, model: "PreTrainedModel", input_name: str, shape: list[int], input_names: list[str]
) -> Dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
"""Generates dummy input for model inference recording.""" """Generates dummy input for model inference recording."""
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
# from pickle, or from the "__class__" attribute in the general case. # from pickle, or from the "__class__" attribute in the general case.
@@ -1181,7 +1180,7 @@ class HFTracer(Tracer):
return attr_val return attr_val
# Needed for PyTorch 1.13+ # Needed for PyTorch 1.13+
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
return self._module_getattr(attr, attr_val, parameter_proxy_cache) return self._module_getattr(attr, attr_val, parameter_proxy_cache)
def call_module(self, m, forward, args, kwargs): def call_module(self, m, forward, args, kwargs):
@@ -1233,8 +1232,8 @@ class HFTracer(Tracer):
def trace( def trace(
self, self,
root: Union[torch.nn.Module, Callable[..., Any]], root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None, concrete_args: Optional[dict[str, Any]] = None,
dummy_inputs: Optional[Dict[str, Any]] = None, dummy_inputs: Optional[dict[str, Any]] = None,
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True, complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph: ) -> Graph:
""" """
@@ -1422,7 +1421,7 @@ class HFTracer(Tracer):
return attribute return attribute
def get_concrete_args(model: nn.Module, input_names: List[str]): def get_concrete_args(model: nn.Module, input_names: list[str]):
sig = inspect.signature(model.forward) sig = inspect.signature(model.forward)
if not (set(input_names) <= set(sig.parameters.keys())): if not (set(input_names) <= set(sig.parameters.keys())):
@@ -1450,9 +1449,9 @@ def check_if_model_is_supported(model: "PreTrainedModel"):
def symbolic_trace( def symbolic_trace(
model: "PreTrainedModel", model: "PreTrainedModel",
input_names: Optional[List[str]] = None, input_names: Optional[list[str]] = None,
disable_check: bool = False, disable_check: bool = False,
tracer_cls: Type[HFTracer] = HFTracer, tracer_cls: type[HFTracer] = HFTracer,
) -> GraphModule: ) -> GraphModule:
""" """
Performs symbolic tracing on the model. Performs symbolic tracing on the model.

View File

@@ -21,12 +21,12 @@ import os
import tempfile import tempfile
import warnings import warnings
from collections import OrderedDict, UserDict from collections import OrderedDict, UserDict
from collections.abc import MutableMapping from collections.abc import Iterable, MutableMapping
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from enum import Enum from enum import Enum
from functools import partial, wraps from functools import partial, wraps
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict from typing import Any, ContextManager, Optional, TypedDict
import numpy as np import numpy as np
from packaging import version from packaging import version
@@ -465,7 +465,7 @@ class ModelOutput(OrderedDict):
args = tuple(getattr(self, field.name) for field in fields(self)) args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining return callable, args, *remaining
def to_tuple(self) -> Tuple[Any]: def to_tuple(self) -> tuple[Any]:
""" """
Convert self to a tuple containing all the attributes/keys that are not `None`. Convert self to a tuple containing all the attributes/keys that are not `None`.
""" """
@@ -475,7 +475,7 @@ class ModelOutput(OrderedDict):
if is_torch_available(): if is_torch_available():
import torch.utils._pytree as _torch_pytree import torch.utils._pytree as _torch_pytree
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]: def _model_output_flatten(output: ModelOutput) -> tuple[list[Any], "_torch_pytree.Context"]:
return list(output.values()), list(output.keys()) return list(output.values()), list(output.keys())
def _model_output_unflatten( def _model_output_unflatten(
@@ -542,7 +542,7 @@ class ContextManagers:
in the `fastcore` library. in the `fastcore` library.
""" """
def __init__(self, context_managers: List[ContextManager]): def __init__(self, context_managers: list[ContextManager]):
self.context_managers = context_managers self.context_managers = context_managers
self.stack = ExitStack() self.stack = ExitStack()
@@ -883,7 +883,7 @@ class LossKwargs(TypedDict, total=False):
num_items_in_batch: Optional[int] num_items_in_batch: Optional[int]
def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool: def is_timm_config_dict(config_dict: dict[str, Any]) -> bool:
"""Checks whether a config dict is a timm config dict.""" """Checks whether a config dict is a timm config dict."""
return "pretrained_cfg" in config_dict return "pretrained_cfg" in config_dict
@@ -903,13 +903,13 @@ def is_timm_local_checkpoint(pretrained_model_path: str) -> bool:
# pretrained_model_path is a file # pretrained_model_path is a file
if is_file and pretrained_model_path.endswith(".json"): if is_file and pretrained_model_path.endswith(".json"):
with open(pretrained_model_path, "r") as f: with open(pretrained_model_path) as f:
config_dict = json.load(f) config_dict = json.load(f)
return is_timm_config_dict(config_dict) return is_timm_config_dict(config_dict)
# pretrained_model_path is a directory with a config.json # pretrained_model_path is a directory with a config.json
if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")): if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")):
with open(os.path.join(pretrained_model_path, "config.json"), "r") as f: with open(os.path.join(pretrained_model_path, "config.json")) as f:
config_dict = json.load(f) config_dict = json.load(f)
return is_timm_config_dict(config_dict) return is_timm_config_dict(config_dict)

View File

@@ -23,7 +23,7 @@ import tempfile
import warnings import warnings
from concurrent import futures from concurrent import futures
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
@@ -168,7 +168,7 @@ def define_sagemaker_information():
return sagemaker_object return sagemaker_object
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: def http_user_agent(user_agent: Union[dict, str, None] = None) -> str:
""" """
Formats a user-agent string with basic info about a request. Formats a user-agent string with basic info about a request.
""" """
@@ -270,17 +270,17 @@ def cached_file(
def cached_files( def cached_files(
path_or_repo_id: Union[str, os.PathLike], path_or_repo_id: Union[str, os.PathLike],
filenames: List[str], filenames: list[str],
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False, force_download: bool = False,
resume_download: Optional[bool] = None, resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None, proxies: Optional[dict[str, str]] = None,
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,
subfolder: str = "", subfolder: str = "",
repo_type: Optional[str] = None, repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None, user_agent: Optional[Union[str, dict[str, str]]] = None,
_raise_exceptions_for_gated_repo: bool = True, _raise_exceptions_for_gated_repo: bool = True,
_raise_exceptions_for_missing_entries: bool = True, _raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True, _raise_exceptions_for_connection_errors: bool = True,
@@ -378,7 +378,7 @@ def cached_files(
if not os.path.isfile(resolved_file): if not os.path.isfile(resolved_file):
if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"): if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"):
revision_ = "main" if revision is None else revision revision_ = "main" if revision is None else revision
raise EnvironmentError( raise OSError(
f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout " f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout "
f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files." f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files."
) )
@@ -410,7 +410,7 @@ def cached_files(
elif not _raise_exceptions_for_missing_entries: elif not _raise_exceptions_for_missing_entries:
file_counter += 1 file_counter += 1
else: else:
raise EnvironmentError(f"Could not locate {filename} inside {path_or_repo_id}.") raise OSError(f"Could not locate {filename} inside {path_or_repo_id}.")
# Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries # Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries
if file_counter == len(full_filenames): if file_counter == len(full_filenames):
@@ -453,14 +453,14 @@ def cached_files(
except Exception as e: except Exception as e:
# We cannot recover from them # We cannot recover from them
if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError): if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError):
raise EnvironmentError( raise OSError(
f"{path_or_repo_id} is not a local folder and is not a valid model identifier " f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token " "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
"having permission to this repo either by logging in with `huggingface-cli login` or by passing " "having permission to this repo either by logging in with `huggingface-cli login` or by passing "
"`token=<your_token>`" "`token=<your_token>`"
) from e ) from e
elif isinstance(e, RevisionNotFoundError): elif isinstance(e, RevisionNotFoundError):
raise EnvironmentError( raise OSError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists " f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at " "for this model name. Check the model page at "
f"'https://huggingface.co/{path_or_repo_id}' for available revisions." f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
@@ -478,7 +478,7 @@ def cached_files(
if isinstance(e, GatedRepoError): if isinstance(e, GatedRepoError):
if not _raise_exceptions_for_gated_repo: if not _raise_exceptions_for_gated_repo:
return None return None
raise EnvironmentError( raise OSError(
"You are trying to access a gated repo.\nMake sure to have access to it at " "You are trying to access a gated repo.\nMake sure to have access to it at "
f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}" f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
) from e ) from e
@@ -488,7 +488,7 @@ def cached_files(
# Here we only raise if both flags for missing entry and connection errors are True (because it can be raised # Here we only raise if both flags for missing entry and connection errors are True (because it can be raised
# even when `local_files_only` is True, in which case raising for connections errors only would not make sense) # even when `local_files_only` is True, in which case raising for connections errors only would not make sense)
elif _raise_exceptions_for_missing_entries: elif _raise_exceptions_for_missing_entries:
raise EnvironmentError( raise OSError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load the files, and couldn't find them in the" f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load the files, and couldn't find them in the"
f" cached files.\nCheckout your internet connection or see how to run the library in offline mode at" f" cached files.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'." " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
@@ -498,9 +498,7 @@ def cached_files(
elif isinstance(e, HTTPError) and not isinstance(e, EntryNotFoundError): elif isinstance(e, HTTPError) and not isinstance(e, EntryNotFoundError):
if not _raise_exceptions_for_connection_errors: if not _raise_exceptions_for_connection_errors:
return None return None
raise EnvironmentError( raise OSError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}")
f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}"
)
resolved_files = [ resolved_files = [
_get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames _get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
@@ -632,7 +630,7 @@ def has_file(
path_or_repo: Union[str, os.PathLike], path_or_repo: Union[str, os.PathLike],
filename: str, filename: str,
revision: Optional[str] = None, revision: Optional[str] = None,
proxies: Optional[Dict[str, str]] = None, proxies: Optional[dict[str, str]] = None,
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
*, *,
local_files_only: bool = False, local_files_only: bool = False,
@@ -707,19 +705,17 @@ def has_file(
return True return True
except GatedRepoError as e: except GatedRepoError as e:
logger.error(e) logger.error(e)
raise EnvironmentError( raise OSError(
f"{path_or_repo} is a gated repository. Make sure to request access at " f"{path_or_repo} is a gated repository. Make sure to request access at "
f"https://huggingface.co/{path_or_repo} and pass a token having permission to this repo either by " f"https://huggingface.co/{path_or_repo} and pass a token having permission to this repo either by "
"logging in with `huggingface-cli login` or by passing `token=<your_token>`." "logging in with `huggingface-cli login` or by passing `token=<your_token>`."
) from e ) from e
except RepositoryNotFoundError as e: except RepositoryNotFoundError as e:
logger.error(e) logger.error(e)
raise EnvironmentError( raise OSError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.") from e
f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'."
) from e
except RevisionNotFoundError as e: except RevisionNotFoundError as e:
logger.error(e) logger.error(e)
raise EnvironmentError( raise OSError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions." f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
) from e ) from e
@@ -780,7 +776,7 @@ class PushToHubMixin:
self, self,
working_dir: Union[str, os.PathLike], working_dir: Union[str, os.PathLike],
repo_id: str, repo_id: str,
files_timestamps: Dict[str, float], files_timestamps: dict[str, float],
commit_message: Optional[str] = None, commit_message: Optional[str] = None,
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
create_pr: bool = False, create_pr: bool = False,
@@ -867,7 +863,7 @@ class PushToHubMixin:
safe_serialization: bool = True, safe_serialization: bool = True,
revision: Optional[str] = None, revision: Optional[str] = None,
commit_description: Optional[str] = None, commit_description: Optional[str] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
**deprecated_kwargs, **deprecated_kwargs,
) -> str: ) -> str:
""" """
@@ -1101,7 +1097,7 @@ def get_checkpoint_shard_files(
if not os.path.isfile(index_filename): if not os.path.isfile(index_filename):
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
with open(index_filename, "r") as f: with open(index_filename) as f:
index = json.loads(f.read()) index = json.loads(f.read())
shard_filenames = sorted(set(index["weight_map"].values())) shard_filenames = sorted(set(index["weight_map"].values()))
@@ -1136,7 +1132,7 @@ def get_checkpoint_shard_files(
def create_and_tag_model_card( def create_and_tag_model_card(
repo_id: str, repo_id: str,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
token: Optional[str] = None, token: Optional[str] = None,
ignore_metadata_errors: bool = False, ignore_metadata_errors: bool = False,
): ):

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face # Copyright 2020 Optuna, Hugging Face
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved. # Copyright 2020 The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2020 Hugging Face # Copyright 2020 Hugging Face
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import importlib import importlib
import os import os
from typing import Dict, Optional, Union from typing import Optional, Union
from packaging import version from packaging import version
@@ -31,7 +31,7 @@ def find_adapter_config_file(
cache_dir: Optional[Union[str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False, force_download: bool = False,
resume_download: Optional[bool] = None, resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None, proxies: Optional[dict[str, str]] = None,
token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
local_files_only: bool = False, local_files_only: bool = False,

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: sentencepiece_model.proto # source: sentencepiece_model.proto
"""Generated protocol buffer code.""" """Generated protocol buffer code."""