[3/N] Use pyupgrade --py39-plus to improve code (#36936)
Use pyupgrade --py39-plus to improve code Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -90,7 +90,7 @@ def get_relative_imports(module_file: Union[str, os.PathLike]) -> list[str]:
|
|||||||
module_file (`str` or `os.PathLike`): The module file to inspect.
|
module_file (`str` or `os.PathLike`): The module file to inspect.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`List[str]`: The list of relative imports in the module.
|
`list[str]`: The list of relative imports in the module.
|
||||||
"""
|
"""
|
||||||
with open(module_file, encoding="utf-8") as f:
|
with open(module_file, encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
@@ -112,7 +112,7 @@ def get_relative_import_files(module_file: Union[str, os.PathLike]) -> list[str]
|
|||||||
module_file (`str` or `os.PathLike`): The module file to inspect.
|
module_file (`str` or `os.PathLike`): The module file to inspect.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
|
`list[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
|
||||||
of module files a given module needs.
|
of module files a given module needs.
|
||||||
"""
|
"""
|
||||||
no_change = False
|
no_change = False
|
||||||
@@ -144,7 +144,7 @@ def get_imports(filename: Union[str, os.PathLike]) -> list[str]:
|
|||||||
filename (`str` or `os.PathLike`): The module file to inspect.
|
filename (`str` or `os.PathLike`): The module file to inspect.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`List[str]`: The list of all packages required to use the input module.
|
`list[str]`: The list of all packages required to use the input module.
|
||||||
"""
|
"""
|
||||||
with open(filename, encoding="utf-8") as f:
|
with open(filename, encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
@@ -175,7 +175,7 @@ def check_imports(filename: Union[str, os.PathLike]) -> list[str]:
|
|||||||
filename (`str` or `os.PathLike`): The module file to check.
|
filename (`str` or `os.PathLike`): The module file to check.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`List[str]`: The list of relative imports in the file.
|
`list[str]`: The list of relative imports in the file.
|
||||||
"""
|
"""
|
||||||
imports = get_imports(filename)
|
imports = get_imports(filename)
|
||||||
missing_packages = []
|
missing_packages = []
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team.
|
# Copyright 2025 The HuggingFace Inc. team.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# coding=utf-8
|
|
||||||
|
|
||||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
@@ -16,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import FrozenSet
|
|
||||||
|
|
||||||
from huggingface_hub import get_full_repo_name # for backward compatibility
|
from huggingface_hub import get_full_repo_name # for backward compatibility
|
||||||
from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
|
from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
|
||||||
@@ -300,8 +298,8 @@ def check_min_version(min_version):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache
|
||||||
def get_available_devices() -> FrozenSet[str]:
|
def get_available_devices() -> frozenset[str]:
|
||||||
"""
|
"""
|
||||||
Returns a frozenset of devices available for the current PyTorch installation.
|
Returns a frozenset of devices available for the current PyTorch installation.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team.
|
# Copyright 2025 The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 The HuggingFace Inc. team.
|
# Copyright 2023 The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -17,7 +16,8 @@
|
|||||||
|
|
||||||
import enum
|
import enum
|
||||||
import inspect
|
import inspect
|
||||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
|
from collections.abc import Iterable
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -75,9 +75,9 @@ def verify_out_features_out_indices(
|
|||||||
|
|
||||||
|
|
||||||
def _align_output_features_output_indices(
|
def _align_output_features_output_indices(
|
||||||
out_features: Optional[List[str]],
|
out_features: Optional[list[str]],
|
||||||
out_indices: Optional[Union[List[int], Tuple[int]]],
|
out_indices: Optional[Union[list[int], tuple[int]]],
|
||||||
stage_names: List[str],
|
stage_names: list[str],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.
|
Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.
|
||||||
@@ -106,10 +106,10 @@ def _align_output_features_output_indices(
|
|||||||
|
|
||||||
|
|
||||||
def get_aligned_output_features_output_indices(
|
def get_aligned_output_features_output_indices(
|
||||||
out_features: Optional[List[str]],
|
out_features: Optional[list[str]],
|
||||||
out_indices: Optional[Union[List[int], Tuple[int]]],
|
out_indices: Optional[Union[list[int], tuple[int]]],
|
||||||
stage_names: List[str],
|
stage_names: list[str],
|
||||||
) -> Tuple[List[str], List[int]]:
|
) -> tuple[list[str], list[int]]:
|
||||||
"""
|
"""
|
||||||
Get the `out_features` and `out_indices` so that they are aligned.
|
Get the `out_features` and `out_indices` so that they are aligned.
|
||||||
|
|
||||||
@@ -198,7 +198,7 @@ class BackboneMixin:
|
|||||||
return self._out_features
|
return self._out_features
|
||||||
|
|
||||||
@out_features.setter
|
@out_features.setter
|
||||||
def out_features(self, out_features: List[str]):
|
def out_features(self, out_features: list[str]):
|
||||||
"""
|
"""
|
||||||
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
|
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
|
||||||
"""
|
"""
|
||||||
@@ -211,7 +211,7 @@ class BackboneMixin:
|
|||||||
return self._out_indices
|
return self._out_indices
|
||||||
|
|
||||||
@out_indices.setter
|
@out_indices.setter
|
||||||
def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
|
def out_indices(self, out_indices: Union[tuple[int], list[int]]):
|
||||||
"""
|
"""
|
||||||
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
|
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
|
||||||
"""
|
"""
|
||||||
@@ -264,7 +264,7 @@ class BackboneConfigMixin:
|
|||||||
return self._out_features
|
return self._out_features
|
||||||
|
|
||||||
@out_features.setter
|
@out_features.setter
|
||||||
def out_features(self, out_features: List[str]):
|
def out_features(self, out_features: list[str]):
|
||||||
"""
|
"""
|
||||||
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
|
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
|
||||||
"""
|
"""
|
||||||
@@ -277,7 +277,7 @@ class BackboneConfigMixin:
|
|||||||
return self._out_indices
|
return self._out_indices
|
||||||
|
|
||||||
@out_indices.setter
|
@out_indices.setter
|
||||||
def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
|
def out_indices(self, out_indices: Union[tuple[int], list[int]]):
|
||||||
"""
|
"""
|
||||||
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
|
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import types
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints
|
from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ class DocstringParsingException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
def _get_json_schema_type(param_type: str) -> dict[str, str]:
|
||||||
type_mapping = {
|
type_mapping = {
|
||||||
int: {"type": "integer"},
|
int: {"type": "integer"},
|
||||||
float: {"type": "number"},
|
float: {"type": "number"},
|
||||||
@@ -87,7 +87,7 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
|
|||||||
return type_mapping.get(param_type, {"type": "object"})
|
return type_mapping.get(param_type, {"type": "object"})
|
||||||
|
|
||||||
|
|
||||||
def _parse_type_hint(hint: str) -> Dict:
|
def _parse_type_hint(hint: str) -> dict:
|
||||||
origin = get_origin(hint)
|
origin = get_origin(hint)
|
||||||
args = get_args(hint)
|
args = get_args(hint)
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ def _parse_type_hint(hint: str) -> Dict:
|
|||||||
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
|
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
|
||||||
|
|
||||||
|
|
||||||
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
def _convert_type_hints_to_json_schema(func: Callable) -> dict:
|
||||||
type_hints = get_type_hints(func)
|
type_hints = get_type_hints(func)
|
||||||
signature = inspect.signature(func)
|
signature = inspect.signature(func)
|
||||||
required = []
|
required = []
|
||||||
@@ -173,7 +173,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|
||||||
def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
|
def parse_google_format_docstring(docstring: str) -> tuple[Optional[str], Optional[dict], Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Parses a Google-style docstring to extract the function description,
|
Parses a Google-style docstring to extract the function description,
|
||||||
argument descriptions, and return description.
|
argument descriptions, and return description.
|
||||||
@@ -206,7 +206,7 @@ def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Option
|
|||||||
return description, args_dict, returns
|
return description, args_dict, returns
|
||||||
|
|
||||||
|
|
||||||
def get_json_schema(func: Callable) -> Dict:
|
def get_json_schema(func: Callable) -> dict:
|
||||||
"""
|
"""
|
||||||
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
|
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
|
||||||
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
|
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
|
||||||
@@ -398,7 +398,7 @@ def _compile_jinja_template(chat_template):
|
|||||||
return self._rendered_blocks or self._generation_indices
|
return self._rendered_blocks or self._generation_indices
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
|
def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
|
||||||
try:
|
try:
|
||||||
if self.is_active():
|
if self.is_active():
|
||||||
raise ValueError("AssistantTracker should not be reused before closed")
|
raise ValueError("AssistantTracker should not be reused before closed")
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -24,7 +23,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
@@ -78,9 +77,9 @@ _IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE
|
|||||||
|
|
||||||
|
|
||||||
def _generate_supported_model_class_names(
|
def _generate_supported_model_class_names(
|
||||||
model_name: Type[PretrainedConfig],
|
model_name: type[PretrainedConfig],
|
||||||
supported_tasks: Optional[Union[str, List[str]]] = None,
|
supported_tasks: Optional[Union[str, list[str]]] = None,
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
task_mapping = {
|
task_mapping = {
|
||||||
"default": MODEL_MAPPING_NAMES,
|
"default": MODEL_MAPPING_NAMES,
|
||||||
"pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
|
"pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
|
||||||
@@ -590,7 +589,7 @@ def operator_getitem(a, b):
|
|||||||
return operator.getitem(a, b)
|
return operator.getitem(a, b)
|
||||||
|
|
||||||
|
|
||||||
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
|
_MANUAL_META_OVERRIDES: dict[Callable, Callable] = {
|
||||||
torch.nn.Embedding: torch_nn_embedding,
|
torch.nn.Embedding: torch_nn_embedding,
|
||||||
torch.nn.functional.embedding: torch_nn_functional_embedding,
|
torch.nn.functional.embedding: torch_nn_functional_embedding,
|
||||||
torch.nn.LayerNorm: torch_nn_layernorm,
|
torch.nn.LayerNorm: torch_nn_layernorm,
|
||||||
@@ -716,7 +715,7 @@ class HFCacheProxy(HFProxy):
|
|||||||
Proxy that represents an instance of `transformers.cache_utils.Cache`.
|
Proxy that represents an instance of `transformers.cache_utils.Cache`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]):
|
def install_orig_cache_cls(self, orig_cache_cls: type[Cache]):
|
||||||
self._orig_cache_cls = orig_cache_cls
|
self._orig_cache_cls = orig_cache_cls
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -770,8 +769,8 @@ class HFProxyableClassMeta(type):
|
|||||||
def __new__(
|
def __new__(
|
||||||
cls,
|
cls,
|
||||||
name: str,
|
name: str,
|
||||||
bases: Tuple[Type, ...],
|
bases: tuple[type, ...],
|
||||||
attrs: Dict[str, Any],
|
attrs: dict[str, Any],
|
||||||
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
|
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
|
||||||
):
|
):
|
||||||
cls = super().__new__(cls, name, bases, attrs)
|
cls = super().__new__(cls, name, bases, attrs)
|
||||||
@@ -794,7 +793,7 @@ class HFProxyableClassMeta(type):
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]:
|
def gen_constructor_wrapper(target: Callable) -> tuple[Callable, Callable]:
|
||||||
"""
|
"""
|
||||||
Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
|
Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
|
||||||
"""
|
"""
|
||||||
@@ -813,7 +812,7 @@ def _proxies_to_metas(v):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
|
def create_cache_proxy_factory_fn(orig_cache_cls: type[Cache]) -> Callable[[Node], HFCacheProxy]:
|
||||||
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
|
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
|
||||||
global _CURRENT_TRACER
|
global _CURRENT_TRACER
|
||||||
if not isinstance(_CURRENT_TRACER, HFTracer):
|
if not isinstance(_CURRENT_TRACER, HFTracer):
|
||||||
@@ -849,7 +848,7 @@ ProxyableStaticCache = HFProxyableClassMeta(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
|
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[list[int]] = None):
|
||||||
if forbidden_values is None:
|
if forbidden_values is None:
|
||||||
forbidden_values = []
|
forbidden_values = []
|
||||||
value = random.randint(low, high)
|
value = random.randint(low, high)
|
||||||
@@ -899,8 +898,8 @@ class HFTracer(Tracer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _generate_dummy_input(
|
def _generate_dummy_input(
|
||||||
self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str]
|
self, model: "PreTrainedModel", input_name: str, shape: list[int], input_names: list[str]
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
"""Generates dummy input for model inference recording."""
|
"""Generates dummy input for model inference recording."""
|
||||||
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
|
||||||
# from pickle, or from the "__class__" attribute in the general case.
|
# from pickle, or from the "__class__" attribute in the general case.
|
||||||
@@ -1181,7 +1180,7 @@ class HFTracer(Tracer):
|
|||||||
return attr_val
|
return attr_val
|
||||||
|
|
||||||
# Needed for PyTorch 1.13+
|
# Needed for PyTorch 1.13+
|
||||||
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
|
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
|
||||||
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
|
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
|
||||||
|
|
||||||
def call_module(self, m, forward, args, kwargs):
|
def call_module(self, m, forward, args, kwargs):
|
||||||
@@ -1233,8 +1232,8 @@ class HFTracer(Tracer):
|
|||||||
def trace(
|
def trace(
|
||||||
self,
|
self,
|
||||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||||
concrete_args: Optional[Dict[str, Any]] = None,
|
concrete_args: Optional[dict[str, Any]] = None,
|
||||||
dummy_inputs: Optional[Dict[str, Any]] = None,
|
dummy_inputs: Optional[dict[str, Any]] = None,
|
||||||
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
|
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
|
||||||
) -> Graph:
|
) -> Graph:
|
||||||
"""
|
"""
|
||||||
@@ -1422,7 +1421,7 @@ class HFTracer(Tracer):
|
|||||||
return attribute
|
return attribute
|
||||||
|
|
||||||
|
|
||||||
def get_concrete_args(model: nn.Module, input_names: List[str]):
|
def get_concrete_args(model: nn.Module, input_names: list[str]):
|
||||||
sig = inspect.signature(model.forward)
|
sig = inspect.signature(model.forward)
|
||||||
|
|
||||||
if not (set(input_names) <= set(sig.parameters.keys())):
|
if not (set(input_names) <= set(sig.parameters.keys())):
|
||||||
@@ -1450,9 +1449,9 @@ def check_if_model_is_supported(model: "PreTrainedModel"):
|
|||||||
|
|
||||||
def symbolic_trace(
|
def symbolic_trace(
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
input_names: Optional[List[str]] = None,
|
input_names: Optional[list[str]] = None,
|
||||||
disable_check: bool = False,
|
disable_check: bool = False,
|
||||||
tracer_cls: Type[HFTracer] = HFTracer,
|
tracer_cls: type[HFTracer] = HFTracer,
|
||||||
) -> GraphModule:
|
) -> GraphModule:
|
||||||
"""
|
"""
|
||||||
Performs symbolic tracing on the model.
|
Performs symbolic tracing on the model.
|
||||||
|
|||||||
@@ -21,12 +21,12 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict, UserDict
|
from collections import OrderedDict, UserDict
|
||||||
from collections.abc import MutableMapping
|
from collections.abc import Iterable, MutableMapping
|
||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from dataclasses import fields, is_dataclass
|
from dataclasses import fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, TypedDict
|
from typing import Any, ContextManager, Optional, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -465,7 +465,7 @@ class ModelOutput(OrderedDict):
|
|||||||
args = tuple(getattr(self, field.name) for field in fields(self))
|
args = tuple(getattr(self, field.name) for field in fields(self))
|
||||||
return callable, args, *remaining
|
return callable, args, *remaining
|
||||||
|
|
||||||
def to_tuple(self) -> Tuple[Any]:
|
def to_tuple(self) -> tuple[Any]:
|
||||||
"""
|
"""
|
||||||
Convert self to a tuple containing all the attributes/keys that are not `None`.
|
Convert self to a tuple containing all the attributes/keys that are not `None`.
|
||||||
"""
|
"""
|
||||||
@@ -475,7 +475,7 @@ class ModelOutput(OrderedDict):
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch.utils._pytree as _torch_pytree
|
import torch.utils._pytree as _torch_pytree
|
||||||
|
|
||||||
def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
|
def _model_output_flatten(output: ModelOutput) -> tuple[list[Any], "_torch_pytree.Context"]:
|
||||||
return list(output.values()), list(output.keys())
|
return list(output.values()), list(output.keys())
|
||||||
|
|
||||||
def _model_output_unflatten(
|
def _model_output_unflatten(
|
||||||
@@ -542,7 +542,7 @@ class ContextManagers:
|
|||||||
in the `fastcore` library.
|
in the `fastcore` library.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, context_managers: List[ContextManager]):
|
def __init__(self, context_managers: list[ContextManager]):
|
||||||
self.context_managers = context_managers
|
self.context_managers = context_managers
|
||||||
self.stack = ExitStack()
|
self.stack = ExitStack()
|
||||||
|
|
||||||
@@ -883,7 +883,7 @@ class LossKwargs(TypedDict, total=False):
|
|||||||
num_items_in_batch: Optional[int]
|
num_items_in_batch: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
def is_timm_config_dict(config_dict: Dict[str, Any]) -> bool:
|
def is_timm_config_dict(config_dict: dict[str, Any]) -> bool:
|
||||||
"""Checks whether a config dict is a timm config dict."""
|
"""Checks whether a config dict is a timm config dict."""
|
||||||
return "pretrained_cfg" in config_dict
|
return "pretrained_cfg" in config_dict
|
||||||
|
|
||||||
@@ -903,13 +903,13 @@ def is_timm_local_checkpoint(pretrained_model_path: str) -> bool:
|
|||||||
|
|
||||||
# pretrained_model_path is a file
|
# pretrained_model_path is a file
|
||||||
if is_file and pretrained_model_path.endswith(".json"):
|
if is_file and pretrained_model_path.endswith(".json"):
|
||||||
with open(pretrained_model_path, "r") as f:
|
with open(pretrained_model_path) as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
return is_timm_config_dict(config_dict)
|
return is_timm_config_dict(config_dict)
|
||||||
|
|
||||||
# pretrained_model_path is a directory with a config.json
|
# pretrained_model_path is a directory with a config.json
|
||||||
if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")):
|
if is_dir and os.path.exists(os.path.join(pretrained_model_path, "config.json")):
|
||||||
with open(os.path.join(pretrained_model_path, "config.json"), "r") as f:
|
with open(os.path.join(pretrained_model_path, "config.json")) as f:
|
||||||
config_dict = json.load(f)
|
config_dict = json.load(f)
|
||||||
return is_timm_config_dict(config_dict)
|
return is_timm_config_dict(config_dict)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import tempfile
|
|||||||
import warnings
|
import warnings
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@@ -168,7 +168,7 @@ def define_sagemaker_information():
|
|||||||
return sagemaker_object
|
return sagemaker_object
|
||||||
|
|
||||||
|
|
||||||
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
def http_user_agent(user_agent: Union[dict, str, None] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Formats a user-agent string with basic info about a request.
|
Formats a user-agent string with basic info about a request.
|
||||||
"""
|
"""
|
||||||
@@ -270,17 +270,17 @@ def cached_file(
|
|||||||
|
|
||||||
def cached_files(
|
def cached_files(
|
||||||
path_or_repo_id: Union[str, os.PathLike],
|
path_or_repo_id: Union[str, os.PathLike],
|
||||||
filenames: List[str],
|
filenames: list[str],
|
||||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||||
force_download: bool = False,
|
force_download: bool = False,
|
||||||
resume_download: Optional[bool] = None,
|
resume_download: Optional[bool] = None,
|
||||||
proxies: Optional[Dict[str, str]] = None,
|
proxies: Optional[dict[str, str]] = None,
|
||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
subfolder: str = "",
|
subfolder: str = "",
|
||||||
repo_type: Optional[str] = None,
|
repo_type: Optional[str] = None,
|
||||||
user_agent: Optional[Union[str, Dict[str, str]]] = None,
|
user_agent: Optional[Union[str, dict[str, str]]] = None,
|
||||||
_raise_exceptions_for_gated_repo: bool = True,
|
_raise_exceptions_for_gated_repo: bool = True,
|
||||||
_raise_exceptions_for_missing_entries: bool = True,
|
_raise_exceptions_for_missing_entries: bool = True,
|
||||||
_raise_exceptions_for_connection_errors: bool = True,
|
_raise_exceptions_for_connection_errors: bool = True,
|
||||||
@@ -378,7 +378,7 @@ def cached_files(
|
|||||||
if not os.path.isfile(resolved_file):
|
if not os.path.isfile(resolved_file):
|
||||||
if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"):
|
if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"):
|
||||||
revision_ = "main" if revision is None else revision
|
revision_ = "main" if revision is None else revision
|
||||||
raise EnvironmentError(
|
raise OSError(
|
||||||
f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout "
|
f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout "
|
||||||
f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files."
|
f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files."
|
||||||
)
|
)
|
||||||
@@ -410,7 +410,7 @@ def cached_files(
|
|||||||
elif not _raise_exceptions_for_missing_entries:
|
elif not _raise_exceptions_for_missing_entries:
|
||||||
file_counter += 1
|
file_counter += 1
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(f"Could not locate {filename} inside {path_or_repo_id}.")
|
raise OSError(f"Could not locate {filename} inside {path_or_repo_id}.")
|
||||||
|
|
||||||
# Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries
|
# Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries
|
||||||
if file_counter == len(full_filenames):
|
if file_counter == len(full_filenames):
|
||||||
@@ -453,14 +453,14 @@ def cached_files(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# We cannot recover from them
|
# We cannot recover from them
|
||||||
if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError):
|
if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError):
|
||||||
raise EnvironmentError(
|
raise OSError(
|
||||||
f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
|
f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
|
||||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
|
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
|
||||||
"having permission to this repo either by logging in with `huggingface-cli login` or by passing "
|
"having permission to this repo either by logging in with `huggingface-cli login` or by passing "
|
||||||
"`token=<your_token>`"
|
"`token=<your_token>`"
|
||||||
) from e
|
) from e
|
||||||
elif isinstance(e, RevisionNotFoundError):
|
elif isinstance(e, RevisionNotFoundError):
|
||||||
raise EnvironmentError(
|
raise OSError(
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||||
"for this model name. Check the model page at "
|
"for this model name. Check the model page at "
|
||||||
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
|
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
|
||||||
@@ -478,7 +478,7 @@ def cached_files(
|
|||||||
if isinstance(e, GatedRepoError):
|
if isinstance(e, GatedRepoError):
|
||||||
if not _raise_exceptions_for_gated_repo:
|
if not _raise_exceptions_for_gated_repo:
|
||||||
return None
|
return None
|
||||||
raise EnvironmentError(
|
raise OSError(
|
||||||
"You are trying to access a gated repo.\nMake sure to have access to it at "
|
"You are trying to access a gated repo.\nMake sure to have access to it at "
|
||||||
f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
|
f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
|
||||||
) from e
|
) from e
|
||||||
@@ -488,7 +488,7 @@ def cached_files(
|
|||||||
# Here we only raise if both flags for missing entry and connection errors are True (because it can be raised
|
# Here we only raise if both flags for missing entry and connection errors are True (because it can be raised
|
||||||
# even when `local_files_only` is True, in which case raising for connections errors only would not make sense)
|
# even when `local_files_only` is True, in which case raising for connections errors only would not make sense)
|
||||||
elif _raise_exceptions_for_missing_entries:
|
elif _raise_exceptions_for_missing_entries:
|
||||||
raise EnvironmentError(
|
raise OSError(
|
||||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load the files, and couldn't find them in the"
|
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load the files, and couldn't find them in the"
|
||||||
f" cached files.\nCheckout your internet connection or see how to run the library in offline mode at"
|
f" cached files.\nCheckout your internet connection or see how to run the library in offline mode at"
|
||||||
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||||
@@ -498,9 +498,7 @@ def cached_files(
|
|||||||
elif isinstance(e, HTTPError) and not isinstance(e, EntryNotFoundError):
|
elif isinstance(e, HTTPError) and not isinstance(e, EntryNotFoundError):
|
||||||
if not _raise_exceptions_for_connection_errors:
|
if not _raise_exceptions_for_connection_errors:
|
||||||
return None
|
return None
|
||||||
raise EnvironmentError(
|
raise OSError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}")
|
||||||
f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
resolved_files = [
|
resolved_files = [
|
||||||
_get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
|
_get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
|
||||||
@@ -632,7 +630,7 @@ def has_file(
|
|||||||
path_or_repo: Union[str, os.PathLike],
|
path_or_repo: Union[str, os.PathLike],
|
||||||
filename: str,
|
filename: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
proxies: Optional[Dict[str, str]] = None,
|
proxies: Optional[dict[str, str]] = None,
|
||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
*,
|
*,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
@@ -707,19 +705,17 @@ def has_file(
|
|||||||
return True
|
return True
|
||||||
except GatedRepoError as e:
|
except GatedRepoError as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise EnvironmentError(
|
raise OSError(
|
||||||
f"{path_or_repo} is a gated repository. Make sure to request access at "
|
f"{path_or_repo} is a gated repository. Make sure to request access at "
|
||||||
f"https://huggingface.co/{path_or_repo} and pass a token having permission to this repo either by "
|
f"https://huggingface.co/{path_or_repo} and pass a token having permission to this repo either by "
|
||||||
"logging in with `huggingface-cli login` or by passing `token=<your_token>`."
|
"logging in with `huggingface-cli login` or by passing `token=<your_token>`."
|
||||||
) from e
|
) from e
|
||||||
except RepositoryNotFoundError as e:
|
except RepositoryNotFoundError as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise EnvironmentError(
|
raise OSError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.") from e
|
||||||
f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'."
|
|
||||||
) from e
|
|
||||||
except RevisionNotFoundError as e:
|
except RevisionNotFoundError as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
raise EnvironmentError(
|
raise OSError(
|
||||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||||
f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
|
f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
|
||||||
) from e
|
) from e
|
||||||
@@ -780,7 +776,7 @@ class PushToHubMixin:
|
|||||||
self,
|
self,
|
||||||
working_dir: Union[str, os.PathLike],
|
working_dir: Union[str, os.PathLike],
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
files_timestamps: Dict[str, float],
|
files_timestamps: dict[str, float],
|
||||||
commit_message: Optional[str] = None,
|
commit_message: Optional[str] = None,
|
||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
create_pr: bool = False,
|
create_pr: bool = False,
|
||||||
@@ -867,7 +863,7 @@ class PushToHubMixin:
|
|||||||
safe_serialization: bool = True,
|
safe_serialization: bool = True,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
commit_description: Optional[str] = None,
|
commit_description: Optional[str] = None,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[list[str]] = None,
|
||||||
**deprecated_kwargs,
|
**deprecated_kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -1101,7 +1097,7 @@ def get_checkpoint_shard_files(
|
|||||||
if not os.path.isfile(index_filename):
|
if not os.path.isfile(index_filename):
|
||||||
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
|
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
|
||||||
|
|
||||||
with open(index_filename, "r") as f:
|
with open(index_filename) as f:
|
||||||
index = json.loads(f.read())
|
index = json.loads(f.read())
|
||||||
|
|
||||||
shard_filenames = sorted(set(index["weight_map"].values()))
|
shard_filenames = sorted(set(index["weight_map"].values()))
|
||||||
@@ -1136,7 +1132,7 @@ def get_checkpoint_shard_files(
|
|||||||
|
|
||||||
def create_and_tag_model_card(
|
def create_and_tag_model_card(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[list[str]] = None,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
ignore_metadata_errors: bool = False,
|
ignore_metadata_errors: bool = False,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 Optuna, Hugging Face
|
# Copyright 2020 Optuna, Hugging Face
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2020 Hugging Face
|
# Copyright 2020 Hugging Face
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
from typing import Dict, Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ def find_adapter_config_file(
|
|||||||
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
||||||
force_download: bool = False,
|
force_download: bool = False,
|
||||||
resume_download: Optional[bool] = None,
|
resume_download: Optional[bool] = None,
|
||||||
proxies: Optional[Dict[str, str]] = None,
|
proxies: Optional[dict[str, str]] = None,
|
||||||
token: Optional[Union[bool, str]] = None,
|
token: Optional[Union[bool, str]] = None,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
local_files_only: bool = False,
|
local_files_only: bool = False,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
# source: sentencepiece_model.proto
|
# source: sentencepiece_model.proto
|
||||||
"""Generated protocol buffer code."""
|
"""Generated protocol buffer code."""
|
||||||
|
|||||||
Reference in New Issue
Block a user