No more Tuple, List, Dict (#38797)
* No more Tuple, List, Dict * make fixup * More style fixes * Docstring fixes with regex replacement * Trigger tests * Redo fixes after rebase * Fix copies * [test all] * update * [test all] * update * [test all] * make style after rebase * Patch the hf_argparser test * Patch the hf_argparser test * style fixes * style fixes * style fixes * Fix docstrings in Cohere test * [test all] --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -33,7 +33,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
from zipfile import is_zipfile
|
||||
|
||||
import torch
|
||||
@@ -339,7 +339,7 @@ def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
||||
except StopIteration:
|
||||
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
@@ -374,7 +374,7 @@ def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
|
||||
return last_dtype
|
||||
|
||||
# For nn.DataParallel compatibility in PyTorch > 1.5
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
@@ -644,7 +644,7 @@ def _get_tied_weight_keys(module: nn.Module, prefix=""):
|
||||
return tied_weight_keys
|
||||
|
||||
|
||||
def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]:
|
||||
def _find_disjoint(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]) -> tuple[list[set[str]], list[str]]:
|
||||
filtered_tensors = []
|
||||
for shared in tensors:
|
||||
if len(shared) < 2:
|
||||
@@ -675,7 +675,7 @@ def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor])
|
||||
return shared_tensors, disjoint_tensors
|
||||
|
||||
|
||||
def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
|
||||
def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]) -> tuple[list[set[str]], set[str]]:
|
||||
shared_tensors = []
|
||||
identical = []
|
||||
for shared in tensors:
|
||||
@@ -738,21 +738,21 @@ def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor
|
||||
@torch.no_grad()
|
||||
def _load_state_dict_into_meta_model(
|
||||
model: "PreTrainedModel",
|
||||
state_dict: Dict,
|
||||
state_dict: dict,
|
||||
shard_file: str,
|
||||
expected_keys: List[str],
|
||||
reverse_renaming_mapping: Dict[str, str],
|
||||
device_map: Optional[Dict] = None,
|
||||
expected_keys: list[str],
|
||||
reverse_renaming_mapping: dict[str, str],
|
||||
device_map: Optional[dict] = None,
|
||||
disk_offload_folder: Optional[str] = None,
|
||||
disk_offload_index: Optional[Dict] = None,
|
||||
disk_offload_index: Optional[dict] = None,
|
||||
cpu_offload_folder: Optional[str] = None,
|
||||
cpu_offload_index: Optional[Dict] = None,
|
||||
cpu_offload_index: Optional[dict] = None,
|
||||
hf_quantizer: Optional[HfQuantizer] = None,
|
||||
is_safetensors: bool = False,
|
||||
keep_in_fp32_regex: Optional[re.Pattern] = None,
|
||||
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
|
||||
unexpected_keys: Optional[list[str]] = None, # passing `unexpected` for cleanup from quantization items
|
||||
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||
) -> Tuple[Optional[Dict], Optional[Dict]]:
|
||||
) -> tuple[Optional[dict], Optional[dict]]:
|
||||
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
|
||||
device in order to easily infer the shapes and dtypes that they will have. The proper parameters are then loaded
|
||||
from `shard_file`, which is the actual state dict file on disk.
|
||||
@@ -998,7 +998,7 @@ def _get_resolved_checkpoint_files(
|
||||
use_safetensors: bool,
|
||||
cache_dir: str,
|
||||
force_download: bool,
|
||||
proxies: Optional[Dict[str, str]],
|
||||
proxies: Optional[dict[str, str]],
|
||||
local_files_only: bool,
|
||||
token: Optional[Union[str, bool]],
|
||||
user_agent: dict,
|
||||
@@ -1006,7 +1006,7 @@ def _get_resolved_checkpoint_files(
|
||||
commit_hash: Optional[str],
|
||||
is_remote_code: bool, # Because we can't determine this inside this function, we need it to be passed in
|
||||
transformers_explicit_filename: Optional[str] = None,
|
||||
) -> Tuple[Optional[List[str]], Optional[Dict]]:
|
||||
) -> tuple[Optional[list[str]], Optional[dict]]:
|
||||
"""Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
|
||||
checkpoints are sharded.
|
||||
This function will download the data if necessary.
|
||||
@@ -1315,13 +1315,13 @@ def _get_resolved_checkpoint_files(
|
||||
|
||||
def _get_torch_dtype(
|
||||
cls,
|
||||
torch_dtype: Optional[Union[str, torch.dtype, Dict]],
|
||||
checkpoint_files: Optional[List[str]],
|
||||
torch_dtype: Optional[Union[str, torch.dtype, dict]],
|
||||
checkpoint_files: Optional[list[str]],
|
||||
config: PretrainedConfig,
|
||||
sharded_metadata: Optional[Dict],
|
||||
state_dict: Optional[Dict],
|
||||
sharded_metadata: Optional[dict],
|
||||
state_dict: Optional[dict],
|
||||
weights_only: bool,
|
||||
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
|
||||
) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
|
||||
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
|
||||
inferred dtype. We do the following:
|
||||
1. If torch_dtype is not None, we use that dtype
|
||||
@@ -1395,12 +1395,12 @@ def _get_torch_dtype(
|
||||
|
||||
def _get_device_map(
|
||||
model: "PreTrainedModel",
|
||||
device_map: Optional[Union[str, Dict]],
|
||||
max_memory: Optional[Dict],
|
||||
device_map: Optional[Union[str, dict]],
|
||||
max_memory: Optional[dict],
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
keep_in_fp32_regex: Optional[re.Pattern],
|
||||
) -> Dict:
|
||||
) -> dict:
|
||||
"""Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential'].
|
||||
Otherwise, we check for any device inconsistencies in the device_map.
|
||||
"""
|
||||
@@ -1472,12 +1472,12 @@ def _get_device_map(
|
||||
def _find_missing_and_unexpected_keys(
|
||||
cls,
|
||||
model: "PreTrainedModel",
|
||||
original_checkpoint_keys: List[str],
|
||||
checkpoint_keys: List[str],
|
||||
original_checkpoint_keys: list[str],
|
||||
checkpoint_keys: list[str],
|
||||
loading_base_model_from_task_state_dict: bool,
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
device_map: Dict,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
device_map: dict,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
|
||||
(keys found in the loaded state dict keys, but that are NOT part of the model parameters)
|
||||
"""
|
||||
@@ -1531,13 +1531,13 @@ def _find_missing_and_unexpected_keys(
|
||||
|
||||
def _find_mismatched_keys(
|
||||
model: "PreTrainedModel",
|
||||
state_dict: Optional[Dict],
|
||||
checkpoint_files: Optional[List[str]],
|
||||
state_dict: Optional[dict],
|
||||
checkpoint_files: Optional[list[str]],
|
||||
ignore_mismatched_sizes: bool,
|
||||
keys_to_rename_mapping: Dict[str, str],
|
||||
keys_to_rename_mapping: dict[str, str],
|
||||
is_quantized: bool,
|
||||
weights_only: bool,
|
||||
) -> Tuple[List[str], List[Tuple[int, int]]]:
|
||||
) -> tuple[list[str], list[tuple[int, int]]]:
|
||||
"""
|
||||
Find potential shape mismatch between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes`
|
||||
is True. Otherwise, return immediately and any shape mismatch that may exist will be raised later on. This avoids checking
|
||||
@@ -1710,7 +1710,7 @@ class ModuleUtilsMixin:
|
||||
return extended_attention_mask
|
||||
|
||||
def get_extended_attention_mask(
|
||||
self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
|
||||
self, attention_mask: Tensor, input_shape: tuple[int], device: torch.device = None, dtype: torch.float = None
|
||||
) -> Tensor:
|
||||
"""
|
||||
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||
@@ -1718,7 +1718,7 @@ class ModuleUtilsMixin:
|
||||
Arguments:
|
||||
attention_mask (`torch.Tensor`):
|
||||
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||
input_shape (`Tuple[int]`):
|
||||
input_shape (`tuple[int]`):
|
||||
The shape of the input to the model.
|
||||
|
||||
Returns:
|
||||
@@ -1853,7 +1853,7 @@ class ModuleUtilsMixin:
|
||||
|
||||
return sum(total_numel)
|
||||
|
||||
def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
|
||||
def estimate_tokens(self, input_dict: dict[str, Union[torch.Tensor, Any]]) -> int:
|
||||
"""
|
||||
Helper function to estimate the total number of tokens from the model inputs.
|
||||
|
||||
@@ -1875,7 +1875,7 @@ class ModuleUtilsMixin:
|
||||
return 0
|
||||
|
||||
def floating_point_ops(
|
||||
self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
|
||||
self, input_dict: dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
|
||||
) -> int:
|
||||
"""
|
||||
Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
|
||||
@@ -2003,9 +2003,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
_supports_attention_backend = False
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
|
||||
def dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
`Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network.
|
||||
`dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network.
|
||||
"""
|
||||
return {"input_ids": torch.tensor(DUMMY_INPUTS)}
|
||||
|
||||
@@ -2108,13 +2108,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# Remove the attribute now that it has been consumed, so it's not saved in the config.
|
||||
delattr(self.config, "gradient_checkpointing")
|
||||
|
||||
def add_model_tags(self, tags: Union[List[str], str]) -> None:
|
||||
def add_model_tags(self, tags: Union[list[str], str]) -> None:
|
||||
r"""
|
||||
Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
|
||||
not overwrite existing tags in the model.
|
||||
|
||||
Args:
|
||||
tags (`Union[List[str], str]`):
|
||||
tags (`Union[list[str], str]`):
|
||||
The desired tags to inject in the model
|
||||
|
||||
Examples:
|
||||
@@ -2203,7 +2203,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
cls,
|
||||
config,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
device_map: Optional[Union[str, Dict[str, int]]] = None,
|
||||
device_map: Optional[Union[str, dict[str, int]]] = None,
|
||||
check_device_map: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -2400,7 +2400,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
cls,
|
||||
config,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
device_map: Optional[Union[str, Dict[str, int]]] = None,
|
||||
device_map: Optional[Union[str, dict[str, int]]] = None,
|
||||
check_device_map: bool = True,
|
||||
hard_check_only: bool = False,
|
||||
) -> PretrainedConfig:
|
||||
@@ -2693,8 +2693,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
def _tie_encoder_decoder_weights(
|
||||
encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
|
||||
):
|
||||
uninitialized_encoder_weights: List[str] = []
|
||||
tied_weights: List[str] = []
|
||||
uninitialized_encoder_weights: list[str] = []
|
||||
tied_weights: list[str] = []
|
||||
if decoder.__class__ != encoder.__class__:
|
||||
logger.info(
|
||||
f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
|
||||
@@ -2706,7 +2706,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
encoder_pointer: nn.Module,
|
||||
module_name: str,
|
||||
base_encoder_name: str,
|
||||
uninitialized_encoder_weights: List[str],
|
||||
uninitialized_encoder_weights: list[str],
|
||||
depth=0,
|
||||
total_decoder_name="",
|
||||
total_encoder_name="",
|
||||
@@ -2809,7 +2809,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
|
||||
|
||||
Returns:
|
||||
`List[str]`: List of modules that should not be split
|
||||
`list[str]`: List of modules that should not be split
|
||||
"""
|
||||
_no_split_modules = set()
|
||||
modules_to_check = [self]
|
||||
@@ -3289,7 +3289,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
|
||||
)
|
||||
|
||||
def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
|
||||
def get_position_embeddings(self) -> Union[nn.Embedding, tuple[nn.Embedding]]:
|
||||
raise NotImplementedError(
|
||||
f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
|
||||
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
|
||||
@@ -3312,12 +3312,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# since from_pretrained(...) calls tie weights anyways
|
||||
self.tie_weights()
|
||||
|
||||
def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
|
||||
def prune_heads(self, heads_to_prune: dict[int, list[int]]):
|
||||
"""
|
||||
Prunes heads of the base model.
|
||||
|
||||
Arguments:
|
||||
heads_to_prune (`Dict[int, List[int]]`):
|
||||
heads_to_prune (`dict[int, list[int]]`):
|
||||
Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
|
||||
to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
|
||||
layer 1 and heads 2 and 3 on layer 2.
|
||||
@@ -3486,7 +3486,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
For backward compatibility with PEFT library, in case adapter weights are attached to the model, all
|
||||
keys of the state dict of adapters needs to be prepended with `base_model.model`. Advanced users can
|
||||
disable this behaviour by setting `save_peft_format` to `False`.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
kwargs (`dict[str, Any]`, *optional*):
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
@@ -3997,7 +3997,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
@classmethod
|
||||
@restore_default_torch_dtype
|
||||
def from_pretrained(
|
||||
cls: Type[SpecificPreTrainedModelType],
|
||||
cls: type[SpecificPreTrainedModelType],
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
*model_args,
|
||||
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
||||
@@ -4057,7 +4057,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
save directory.
|
||||
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
||||
configuration JSON file named *config.json* is found in the directory.
|
||||
state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
||||
state_dict (`dict[str, torch.Tensor]`, *optional*):
|
||||
A state dictionary to use instead of a state dictionary loaded from saved weights file.
|
||||
|
||||
This option can be used if you want to create a model from a pretrained configuration but load your own
|
||||
@@ -4082,7 +4082,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible.
|
||||
Will be removed in v5 of Transformers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
proxies (`dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
@@ -4131,7 +4131,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
</Tip>
|
||||
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
|
||||
device_map (`str` or `dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
|
||||
@@ -4179,7 +4179,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
Indicates whether unpickler should be restricted to loading only tensors, primitive types,
|
||||
dictionaries and any types added via torch.serialization.add_safe_globals().
|
||||
When set to False, we can load wrapper tensor subclass weights.
|
||||
key_mapping (`Dict[str, str]`, *optional*):
|
||||
key_mapping (`dict[str, str]`, *optional*):
|
||||
A potential mapping of the weight names if using a model on the Hub which is compatible to a Transformers
|
||||
architecture, but was not converted accordingly.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
@@ -4799,7 +4799,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
|
||||
def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]:
|
||||
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
|
||||
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
|
||||
# This rename is logged.
|
||||
@@ -4826,8 +4826,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
def _get_key_renaming_mapping(
|
||||
self,
|
||||
checkpoint_keys: List[str],
|
||||
key_mapping: Optional[Dict[str, str]] = None,
|
||||
checkpoint_keys: list[str],
|
||||
key_mapping: Optional[dict[str, str]] = None,
|
||||
loading_base_model_from_task_state_dict: bool = False,
|
||||
loading_task_model_from_base_state_dict: bool = False,
|
||||
):
|
||||
@@ -4885,7 +4885,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
return key_renaming_mapping
|
||||
|
||||
@staticmethod
|
||||
def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
|
||||
def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
|
||||
"""
|
||||
Similar to `_fix_state_dict_key_on_load`, this allows defining a hook for state dict key renaming on model save.
|
||||
Do nothing by default, but can be overridden in particular models.
|
||||
@@ -4903,19 +4903,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
model: "PreTrainedModel",
|
||||
state_dict: Optional[Dict],
|
||||
checkpoint_files: Optional[List[str]],
|
||||
state_dict: Optional[dict],
|
||||
checkpoint_files: Optional[list[str]],
|
||||
pretrained_model_name_or_path: Optional[str],
|
||||
ignore_mismatched_sizes: bool = False,
|
||||
sharded_metadata: Optional[Dict] = None,
|
||||
device_map: Optional[Dict] = None,
|
||||
sharded_metadata: Optional[dict] = None,
|
||||
device_map: Optional[dict] = None,
|
||||
disk_offload_folder: Optional[str] = None,
|
||||
offload_state_dict: Optional[bool] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
hf_quantizer: Optional[HfQuantizer] = None,
|
||||
keep_in_fp32_regex: Optional[re.Pattern] = None,
|
||||
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
|
||||
key_mapping: Optional[Dict[str, str]] = None,
|
||||
key_mapping: Optional[dict[str, str]] = None,
|
||||
weights_only: bool = True,
|
||||
):
|
||||
# Useful flags
|
||||
@@ -5485,8 +5485,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
def _move_missing_keys_from_meta_to_cpu(
|
||||
self,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
missing_keys: list[str],
|
||||
unexpected_keys: list[str],
|
||||
dtype: Optional[torch.dtype],
|
||||
hf_quantizer: Optional[HfQuantizer],
|
||||
) -> "PreTrainedModel":
|
||||
@@ -5520,7 +5520,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
def _initialize_missing_keys(
|
||||
self,
|
||||
loaded_keys: List[str],
|
||||
loaded_keys: list[str],
|
||||
ignore_mismatched_sizes: bool,
|
||||
is_quantized: bool,
|
||||
) -> "PreTrainedModel":
|
||||
@@ -5846,7 +5846,7 @@ class SQuADHead(nn.Module):
|
||||
is_impossible: Optional[torch.LongTensor] = None,
|
||||
p_mask: Optional[torch.FloatTensor] = None,
|
||||
return_dict: bool = False,
|
||||
) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
|
||||
) -> Union[SquadHeadOutput, tuple[torch.FloatTensor]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
|
||||
@@ -6090,7 +6090,7 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
|
||||
return torch.device(device).type not in ["meta", "cpu"]
|
||||
|
||||
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, hf_quantizer: Optional[HfQuantizer]):
|
||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
|
||||
"""This function warms up the caching allocator based on the size of the model tensors that will reside on each
|
||||
device. It makes it possible to have one large call to Malloc, instead of recursively calling it later when loading
|
||||
the model, which is actually the loading speed bottleneck.
|
||||
|
||||
Reference in New Issue
Block a user