From 071a161d3e38f56dbda2743b979f0afeed2cd4f1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 12 Mar 2025 13:39:25 +0100 Subject: [PATCH] [core] Large/full refactor of `from_pretrained` (#36033) * squash everything together start to simplify inner logic Update modeling_utils.py Update modeling_utils.py Update modeling_utils.py Update modeling_utils.py continue refactor fix small fixes add type hints/docstring Update modeling_utils.py remove _fast_init keep improving Update modeling_utils.py Update modeling_utils.py new first tp loading version style fix weird in-place op trigger CIs Update modeling_utils.py much clearer renaming of keys fix update Update test_modeling_common.py trigger CIs update update style Update modeling_utils.py Update modeling_utils.py Update modeling_utils.py fix fast download first prototype remove old function remove old functions Remove unused function and move back _get_tp_registry fix tp plan registry simplify CIs Update hub.py Update modeling_utils.py simplify simplify renaming logic remove unused check add sanity check back (a test depends on it) Update modeling_utils.py finalize sound renaming logic style add forgotten check Update modeling_utils.py add key_mapping keyword style Update modeling_utils.py add comment minor updates minor change for clarity fix small prefix issue and simplify style trigger CIs typo fix Post rebase fix post rebase cleanup simplify tp typo oupsi typo correctly escape improvements based on Marc's review finalize Marc's review comments squash everything * improve * Update modeling_utils.py * Update modeling_utils.py * fix * Update modeling_utils.py * Update modeling_utils.py * style * Update modeling_utils.py * simplify * style * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * fix dtype issue * Update modeling_utils.py * style * remove test that does not make sense * style * small fixes * style * fix * cleanup after rebase * style * typo * escape * tp for task specific top modules * Update modeling_utils.py * Update modeling_utils.py * fix allocation * CIs * CIs * CIs * improve docstring * CIs * Update modeling_utils.py * fix --- src/transformers/file_utils.py | 1 - src/transformers/integrations/deepspeed.py | 4 +- src/transformers/modeling_utils.py | 2427 ++++++++--------- .../models/auto/feature_extraction_auto.py | 7 +- .../models/auto/image_processing_auto.py | 7 +- .../models/auto/processing_auto.py | 28 +- .../models/bark/processing_bark.py | 12 +- src/transformers/models/cvt/modeling_cvt.py | 2 +- .../convert_regnet_seer_10b_to_pytorch.py | 16 +- .../timm_wrapper/modeling_timm_wrapper.py | 2 +- src/transformers/utils/__init__.py | 1 - src/transformers/utils/hub.py | 452 +-- tests/test_modeling_common.py | 13 +- tests/utils/test_hub_utils.py | 69 +- tests/utils/test_modeling_utils.py | 26 +- 15 files changed, 1525 insertions(+), 1542 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 4fae91f43f..ac6b36d2db 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -71,7 +71,6 @@ from .utils import ( copy_func, default_cache_path, define_sagemaker_information, - get_file_from_repo, get_torch_version, has_file, http_user_agent, diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py index 1b51a53164..c896a5aa86 100644 --- a/src/transformers/integrations/deepspeed.py +++ b/src/transformers/integrations/deepspeed.py @@ -306,7 +306,7 @@ def deepspeed_config(): return None -def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False): +def _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_to_params_buffers=False): """ Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers` tensor parallelism API. @@ -349,7 +349,7 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, a if child is not None: load(child, state_dict, prefix + name + ".", assign_to_params_buffers) - load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers) + load(model_to_load, state_dict, assign_to_params_buffers=assign_to_params_buffers) # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so # it's safe to delete it. del state_dict diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8462fb84b1..61c86cffd0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -57,7 +57,6 @@ from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.tensor_parallel import ( SUPPORTED_TP_STYLES, shard_and_distribute_module, - translate_to_torch_parallel_style, ) from .loss.loss_utils import LOSS_MAPPING from .pytorch_utils import ( # noqa: F401 @@ -403,7 +402,7 @@ def dtype_byte_size(dtype): return bit_size // 8 -def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""): +def check_support_param_buffer_assignment(model_to_load, state_dict): """ Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first checking @@ -412,7 +411,7 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi Note: We fully disable this if we are using `deepspeed` """ - if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: + if len(state_dict) == 0: return False if is_deepspeed_zero3_enabled(): @@ -427,8 +426,8 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefi # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype first_key = next(iter(model_to_load.state_dict().keys())) - if start_prefix + first_key in state_dict: - return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype + if first_key in state_dict: + return state_dict[first_key].dtype == model_to_load.state_dict()[first_key].dtype # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`) return False @@ -446,9 +445,9 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): Args: model (`torch.nn.Module`): The model in which to load the checkpoint. folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint. - strict (`bool`, *optional`, defaults to `True`): + strict (`bool`, *optional*, defaults to `True`): Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. - prefer_safe (`bool`, *optional*, defaults to `False`) + prefer_safe (`bool`, *optional*, defaults to `False`): If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible. @@ -709,174 +708,79 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor] return shared_tensors, identical -def find_submodule_and_param_name(model, long_key, start_prefix=""): - """ - A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed - from the start of the key - """ - - if len(start_prefix) > 0 and long_key.startswith(start_prefix): - long_key = ".".join(long_key.split(".")[1:]) - - split_key = long_key.split(".") - submodule = model - while len(split_key) > 1: - if hasattr(submodule, split_key[0]): - submodule = getattr(submodule, split_key[0]) - del split_key[0] +def _infer_parameter_dtype( + model: "PreTrainedModel", param_name: str, empty_param, keep_in_fp32_modules=None +) -> Union[bool, Optional[torch.dtype]]: + old_param = model.get_parameter_or_buffer(param_name) + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + casting_dtype = None + is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn + if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: + # First fp32 if part of the exception list + if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name): + casting_dtype = torch.float32 + # Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes else: - submodule = None - break - if submodule == model: - submodule = None - return submodule, split_key[0] - - -def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix): - """ - Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params. - - `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in - `bert.pooler.dense.weight` - - """ - - # dematerialize param storage for keys that are going to be replaced by state_dict, by - # putting those on the meta device - for k in loaded_state_dict_keys: - submodule, param_name = find_submodule_and_param_name(model, k, start_prefix) - if submodule is not None: - # selectively switch to the meta device only those params/buffers that will - # be next replaced from state_dict. This a complex way to do p.to_("meta") - # since we have no in-place to_ for tensors. - new_val = getattr(submodule, param_name) - if isinstance(new_val, torch.nn.Parameter): - # isinstance returns False for Params on meta device, so switch after the check - new_val = torch.nn.Parameter(new_val.to("meta")) - else: - new_val = new_val.to("meta") - setattr(submodule, param_name, new_val) - - -def fix_tensor_type_and_device( - model, param_name, param, dtype=None, keep_in_fp32_modules=None -) -> Union[str, torch.dtype]: - # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which - # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. - # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 - - old_param = model - if "." in param_name: - pre, _ = param_name.rsplit(".", 1) - - old_param = model.get_submodule(pre) - if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): - old_param = None - - is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") - # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params - # in int/uint/bool and not cast them. - param_casting_dtype = None - is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn - if param.dtype.is_floating_point and not is_param_float8_e4m3fn: - if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name): - param_casting_dtype = torch.float32 - elif dtype is not None: - param_casting_dtype = dtype - elif old_param is not None: - param_casting_dtype = old_param.dtype - return old_param is not None and old_param.is_contiguous(), param_casting_dtype - else: - return False, None - - return + casting_dtype = old_param.dtype + return old_param is not None and old_param.is_contiguous(), casting_dtype @torch.no_grad() def _load_state_dict_into_meta_model( - model: torch.nn.Module, - state_dict: Dict[str, torch.Tensor], - start_prefix, - expected_keys, - device_map=None, - offload_folder=None, - offload_index=None, - state_dict_folder=None, - state_dict_index=None, - dtype=None, - hf_quantizer=None, - is_safetensors=False, - keep_in_fp32_modules=None, - unexpected_keys=None, # passing `unexpected` for cleanup from quantization items - device_mesh=None, - shard_file=None, + model: "PreTrainedModel", + state_dict: Dict, + shard_file: str, + expected_keys: List[str], + reverse_renaming_mapping: Dict[str, str], + device_map: Optional[Dict] = None, + disk_offload_folder: Optional[str] = None, + disk_offload_index: Optional[Dict] = None, + cpu_offload_folder: Optional[str] = None, + cpu_offload_index: Optional[Dict] = None, + hf_quantizer: Optional[HfQuantizer] = None, + is_safetensors: bool = False, + keep_in_fp32_modules: Optional[List[str]] = None, + unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items + device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, weights_only=True, -): - """ - This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its - params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the - params back to the normal device, but only for `loaded_state_dict_keys`. - - `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in - `bert.pooler.dense.weight` - - It also initialize tensor parallelism for each module if needed. - +) -> Tuple[Optional[Dict], Optional[Dict]]: + """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta + device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded + from `shard_file`, which is the actual state dict file on disk. + This function takes care of correctly casting dtypes, devices, and sharding tensors in case of tensor parallelism. """ tensor_device = "cpu" if device_map is not None and device_map.get("", None) is not None: tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""] if device_map is not None: - device_map_regex = "|".join(sorted(device_map.keys(), reverse=True)) - - file_pointer = None - bin_state_dict = None - if shard_file.endswith(".safetensors"): - file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) - elif shard_file.endswith(".bin"): - map_location = "cpu" - if ( - device_map is not None - and hf_quantizer is not None - and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] - ): - map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) - bin_state_dict = load_state_dict(shard_file, map_location=map_location, weights_only=weights_only) - - error_msgs = [] + device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)]) is_quantized = hf_quantizer is not None + is_meta_state_dict = shard_file.endswith(".safetensors") and not is_quantized - # get full state dict - if is_quantized: - if shard_file.endswith(".safetensors"): - full_state_dict = load_state_dict(shard_file, map_location="cpu") - elif shard_file.endswith(".bin"): - full_state_dict = bin_state_dict + file_pointer = None + if is_meta_state_dict: + file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) - for serialized_param_name, empty_param in state_dict.items(): - # serialized_param_name is the raw, serialized name - # fixed_param_name is the model's equivalent - fixed_param_name, _ = model.rename_key(serialized_param_name) - - if fixed_param_name not in expected_keys: + for param_name, empty_param in state_dict.items(): + if param_name not in expected_keys: continue # we need to use serialized_param_name as file pointer is untouched - if shard_file.endswith(".safetensors"): + if is_meta_state_dict: + # This is the name of the parameter as it appears on disk file + serialized_param_name = reverse_renaming_mapping[param_name] param = file_pointer.get_slice(serialized_param_name) - elif shard_file.endswith(".gguf"): - param = empty_param # For gguf the dict is actually not empty! else: - param = bin_state_dict[serialized_param_name] + param = empty_param # It is actually not empty! - to_contiguous, param_casting_dtype = fix_tensor_type_and_device( + to_contiguous, casting_dtype = _infer_parameter_dtype( model, - param_name=fixed_param_name, - param=empty_param, - dtype=dtype, - keep_in_fp32_modules=keep_in_fp32_modules, + param_name, + empty_param, + keep_in_fp32_modules, ) if device_mesh is not None: # In this case, the param is already on the correct device! @@ -884,33 +788,33 @@ def _load_state_dict_into_meta_model( model, param, empty_param, - fixed_param_name, - param_casting_dtype, + param_name, + casting_dtype, to_contiguous, tensor_device, # the rank device_mesh, ) else: param = param[:] - if param_casting_dtype is not None: - param = param.to(param_casting_dtype) + if casting_dtype is not None: + param = param.to(casting_dtype) if to_contiguous: param = param.contiguous() if device_map is None: param_device = "cpu" else: - module_layer = re.search(device_map_regex, fixed_param_name) + module_layer = re.search(device_map_regex, param_name) if not module_layer: - raise ValueError(f"{fixed_param_name} doesn't have any device set.") + raise ValueError(f"{param_name} doesn't have any device set.") else: param_device = device_map[module_layer.group()] if param_device == "disk": if not is_safetensors: - offload_index = offload_weight(param, fixed_param_name, offload_folder, offload_index) - elif param_device == "cpu" and state_dict_index is not None: - state_dict_index = offload_weight(param, fixed_param_name, state_dict_folder, state_dict_index) + disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index) + elif param_device == "cpu" and cpu_offload_index is not None: + cpu_offload_index = offload_weight(param, param_name, cpu_offload_folder, cpu_offload_index) elif ( not is_quantized or (not hf_quantizer.requires_parameters_quantization) @@ -918,8 +822,8 @@ def _load_state_dict_into_meta_model( not hf_quantizer.check_quantized_param( model, param, - fixed_param_name, - full_state_dict, + param_name, + state_dict, param_device=param_device, device_map=device_map, ) @@ -927,7 +831,7 @@ def _load_state_dict_into_meta_model( ): if is_fsdp_enabled(): param_device = "cpu" if is_local_dist_rank_0() else "meta" - module, param_type = get_module_from_name(model, fixed_param_name) + module, param_type = get_module_from_name(model, param_name) module.load_state_dict( {param_type: param.to(param_device)}, strict=False, @@ -935,13 +839,13 @@ def _load_state_dict_into_meta_model( ) else: hf_quantizer.create_quantized_param( - model, param, fixed_param_name, param_device, full_state_dict, unexpected_keys + model, param, param_name, param_device, state_dict, unexpected_keys ) # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU # and then cast it to CPU to avoid excessive memory usage on each GPU # in comparison to the sharded model across GPUs. if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): - module, param_type = get_module_from_name(model, fixed_param_name) + module, param_type = get_module_from_name(model, param_name) value = getattr(module, param_type) param_to = "cpu" if is_fsdp_enabled() and not is_local_dist_rank_0(): @@ -955,7 +859,7 @@ def _load_state_dict_into_meta_model( if file_pointer is not None: file_pointer.__exit__(None, None, None) - return error_msgs, offload_index, state_dict_index + return disk_offload_index, cpu_offload_index def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: @@ -965,6 +869,560 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name +def _get_resolved_checkpoint_files( + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + subfolder: str, + variant: Optional[str], + gguf_file: Optional[str], + from_tf: bool, + from_flax: bool, + use_safetensors: bool, + cache_dir: str, + force_download: bool, + proxies: Optional[Dict[str, str]], + local_files_only: bool, + token: Optional[Union[str, bool]], + user_agent: dict, + revision: str, + commit_hash: Optional[str], +) -> Tuple[Optional[List[str]], Optional[Dict]]: + """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the + checkpoints are sharded. + This function will download the data if necesary. + """ + is_sharded = False + + if pretrained_model_name_or_path is not None and gguf_file is None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + ): + # Load from a TF 1.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)): + # Load from a TF 2.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + elif from_flax and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + ): + # Load from a Flax checkpoint in priority if from_flax + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + elif not use_safetensors and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) + ) + elif not use_safetensors and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + # At this stage we don't have a weight file so we will raise an error. + elif not use_safetensors and ( + os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")) + or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)) + ): + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use" + " `from_tf=True` to load this model from those weights." + ) + elif not use_safetensors and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + ): + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`" + " to load this model from those weights." + ) + elif use_safetensors: + raise EnvironmentError( + f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path}." + ) + else: + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" + f" {pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")): + if not from_tf: + raise ValueError( + f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set " + "from_tf to True to load from this checkpoint." + ) + archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index") + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if from_tf: + filename = TF2_WEIGHTS_NAME + elif from_flax: + filename = FLAX_WEIGHTS_NAME + elif use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + if revision == "main": + resolved_archive_file, revision, is_sharded = auto_conversion( + pretrained_model_name_or_path, **cached_file_kwargs + ) + cached_file_kwargs["revision"] = revision + if resolved_archive_file is None: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + if not local_files_only and not is_offline_mode(): + if resolved_archive_file is not None: + if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]: + # If the PyTorch file was found, check if there is a safetensors file on the repository + # If there is no safetensors file on the repositories, start an auto conversion + safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "local_files_only": local_files_only, + "user_agent": user_agent, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + **has_file_kwargs, + } + if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): + Thread( + target=auto_conversion, + args=(pretrained_model_name_or_path,), + kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs}, + name="Thread-auto_conversion", + ).start() + else: + # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. + # We try those to give a helpful error message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights." + " Use `from_tf=True` to load this model from those weights." + ) + elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use" + " `from_flax=True` to load this model from those weights." + ) + elif variant is not None and has_file( + pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs + ): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) + + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) from e + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + + elif gguf_file: + # Case 1: the GGUF file is present locally + if os.path.isfile(gguf_file): + resolved_archive_file = gguf_file + # Case 2: The GGUF path is a location on the Hub + # Load from URL or cache if already cached + else: + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + resolved_archive_file = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs) + + # We now download and resolve all checkpoint files if the checkpoint is sharded + sharded_metadata = None + if is_sharded: + checkpoint_files, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + else: + checkpoint_files = [resolved_archive_file] if pretrained_model_name_or_path is not None else None + + return checkpoint_files, sharded_metadata + + +def _get_torch_dtype( + cls, + torch_dtype: Optional[Union[str, torch.dtype, Dict]], + checkpoint_files: Optional[List[str]], + config: PretrainedConfig, + sharded_metadata: Optional[Dict], + state_dict: Optional[Dict], + weights_only: bool, +) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]: + """Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the + infered dtype. We do the following: + 1. If torch_dtype is not None, we use that dtype + 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first + weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype + we also may have config.torch_dtype available, but we won't rely on it till v5 + """ + dtype_orig = None + is_sharded = sharded_metadata is not None + + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + torch_dtype = config.torch_dtype + logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object") + else: + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + elif state_dict is not None: + torch_dtype = get_state_dict_dtype(state_dict) + else: + state_dict = load_state_dict( + checkpoint_files[0], map_location="meta", weights_only=weights_only + ) + torch_dtype = get_state_dict_dtype(state_dict) + logger.info( + "Since the `torch_dtype` attribute can't be found in model's config object, " + "will use torch_dtype={torch_dtype} as derived from model's weights" + ) + elif hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) + for sub_config_key in config.sub_configs.keys(): + sub_config = getattr(config, sub_config_key) + sub_config.torch_dtype = torch_dtype + elif isinstance(torch_dtype, torch.dtype): + for sub_config_key in config.sub_configs.keys(): + sub_config = getattr(config, sub_config_key) + sub_config.torch_dtype = torch_dtype + elif isinstance(torch_dtype, dict): + for key, curr_dtype in torch_dtype.items(): + if hasattr(config, key): + value = getattr(config, key) + value.torch_dtype = curr_dtype + # main torch dtype for modules that aren't part of any sub-config + torch_dtype = torch_dtype.get("") + config.torch_dtype = torch_dtype + if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) + elif torch_dtype is None: + torch_dtype = torch.float32 + else: + raise ValueError( + f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` " + f"for each sub-config in composite configs, but received {torch_dtype}" + ) + + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + else: + # set fp32 as the default dtype for BC + default_dtype = str(torch.get_default_dtype()).split(".")[-1] + config.torch_dtype = default_dtype + for key in config.sub_configs.keys(): + value = getattr(config, key) + value.torch_dtype = default_dtype + + return config, torch_dtype, dtype_orig + + +def _get_device_map( + model: "PreTrainedModel", + device_map: Optional[Union[str, Dict]], + max_memory: Optional[Dict], + hf_quantizer: Optional[HfQuantizer], + torch_dtype: Optional[torch.dtype], + keep_in_fp32_modules: Optional[List[str]], +) -> Dict: + """Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential']. + Otherwise, we check for any device inconsistencies in the device_map. + """ + if isinstance(device_map, str): + special_dtypes = {} + if hf_quantizer is not None: + special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) + if keep_in_fp32_modules is not None: + special_dtypes.update( + { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in keep_in_fp32_modules) + } + ) + + target_dtype = torch_dtype + + if hf_quantizer is not None: + target_dtype = hf_quantizer.adjust_target_dtype(target_dtype) + + no_split_modules = model._get_no_split_modules(device_map) + device_map_kwargs = {"no_split_module_classes": no_split_modules} + + if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: + device_map_kwargs["special_dtypes"] = special_dtypes + elif len(special_dtypes) > 0: + logger.warning( + "This model has some weights that should be kept in higher precision, you need to upgrade " + "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." + ) + + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + dtype=target_dtype, + low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, + **device_map_kwargs, + ) + else: + max_memory = get_max_memory(max_memory) + if hf_quantizer is not None: + max_memory = hf_quantizer.adjust_max_memory(max_memory) + device_map_kwargs["max_memory"] = max_memory + + device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) + + elif device_map is not None: + tied_params = find_tied_parameters(model) + # check if we don't have tied param in different devices + check_tied_parameters_on_same_device(tied_params, device_map) + + return device_map + + +def _find_missing_and_unexpected_keys( + cls, + model: "PreTrainedModel", + original_checkpoint_keys: List[str], + checkpoint_keys: List[str], + loading_base_model_from_task_state_dict: bool, + hf_quantizer: Optional[HfQuantizer], + device_map: Dict, +) -> Tuple[List[str], List[str]]: + """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys + (keys found in the loaded state dict keys, but that are NOT part of the model parameters) + """ + prefix = model.base_model_prefix + + # Compute expected keys, i.e. keys that the FULL model (not model_to_load) expects + expected_keys = list(model.state_dict().keys()) + if hf_quantizer is not None: + expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys) + + # Adjust prefix of the keys to make them match loaded keys before removing them + missing_keys = sorted(set(expected_keys) - set(checkpoint_keys)) + unexpected_keys = set(checkpoint_keys) - set(expected_keys) + # If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys + if loading_base_model_from_task_state_dict: + task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")] + unexpected_keys.update(task_specific_keys) + + # Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but + # may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway + model_buffers = {n for n, _ in model.named_buffers()} + unexpected_keys = sorted(unexpected_keys - model_buffers) + + # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model + # (so the buffer name has changed). Remove them in such a case + has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers) + if has_inv_freq_buffers: + unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k] + + if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + ptrs = collections.defaultdict(list) + for name, tensor in model.state_dict().items(): + id_tensor = id_tensor_storage(tensor) + ptrs[id_tensor].append(name) + + # These are all the pointers of shared tensors. + tied_params = [names for _, names in ptrs.items() if len(names) > 1] + else: + # id function doesn't work for meta tensor so we need this function + tied_params = find_tied_parameters(model) + + for group in tied_params: + missing_in_group = [k for k in missing_keys if k in group] + if len(missing_in_group) > 0 and len(missing_in_group) < len(group): + missing_keys = [k for k in missing_keys if k not in missing_in_group] + + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) + + # Model-specific exceptions for missing and unexpected keys (e.g. if the modeling change over time, or any other reason...) + if cls._keys_to_ignore_on_load_missing is not None: + for pattern in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pattern, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pattern in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None] + + return missing_keys, unexpected_keys + + +def _find_mismatched_keys( + model_to_load: "PreTrainedModel", + state_dict: Dict, + ignore_mismatched_sizes: bool, + prefix: str, +) -> List: + """Find mismatch of shapes between the model parameters and the loaded state dict, and optionally remove the + problematic keys from `state_dict` if `ignore_mismatched_sizes` is `True`.""" + mismatched_keys = [] + if ignore_mismatched_sizes: + model_state_dict = model_to_load.state_dict() + state_dict_keys = list(state_dict.keys()) + for key in state_dict_keys: + if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape: + if state_dict[key].shape[-1] == 1 and state_dict[key].numel() * 2 == model_state_dict[key].numel(): + # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. + # Without matching with module type or paramter type it seems like a practical way to detect valid 4bit weights. + pass + else: + # Add prefix if we removed it before, to add the correct state dict key to the warnings + key_with_prefix = prefix + key + mismatched_keys.append((key_with_prefix, state_dict[key].shape, model_state_dict[key].shape)) + del state_dict[key] + return mismatched_keys + + class PipelineParallel(Enum): inputs: 0 outputs: 1 @@ -1281,45 +1739,6 @@ class ModuleUtilsMixin: return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) -def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - original_loaded_keys, - add_prefix_to_model, - remove_prefix_from_model, - ignore_mismatched_sizes, - prefix, -): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key, model_key in zip(original_loaded_keys, loaded_keys): - # If the checkpoint is sharded, we may not have the key here. - if checkpoint_key not in state_dict: - continue - if remove_prefix_from_model: - # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. - model_key = f"{prefix}.{model_key}" - elif add_prefix_to_model: - # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. - model_key = ".".join(model_key.split(".")[1:]) - - if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape: - if ( - state_dict[checkpoint_key].shape[-1] == 1 - and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel() - ): - # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. - # Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights. - pass - else: - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys - - # TODO (joao): remove `GenerationMixin` inheritance in v4.50 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): r""" @@ -3423,11 +3842,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. - - mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. _fast_init(`bool`, *optional*, defaults to `True`): Whether or not to disable fast initialization. @@ -3487,8 +3901,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): - A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + A dictionary device identifier to maximum memory if using `device_map`. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. + tp_plan (`str`, *optional*): + A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts + `tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with + `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations. offload_folder (`str` or `os.PathLike`, *optional*): If the `device_map` contains any value `"disk"`, the folder where we will offload weights. offload_state_dict (`bool`, *optional*): @@ -3512,12 +3930,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix use_safetensors (`bool`, *optional*, defaults to `None`): Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` is not installed, it will be set to `False`. - weights_only (`bool`, *optional*, defaults to `True`): Indicates whether unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals(). When set to False, we can load wrapper tensor subclass weights. - + key_mapping (`Dict[str, str], *optional*): + A potential mapping of the weight names if using a model on the Hub which is compatible to a Transformers + architecture, but was not converted accordingly. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., `output_attentions=True`). Behaves differently depending on whether a `config` is provided or @@ -3577,7 +3996,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix state_dict = kwargs.pop("state_dict", None) from_tf = kwargs.pop("from_tf", False) from_flax = kwargs.pop("from_flax", False) - resume_download = kwargs.pop("resume_download", None) + _ = kwargs.pop("resume_download", None) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) use_auth_token = kwargs.pop("use_auth_token", None) @@ -3603,12 +4022,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix adapter_name = kwargs.pop("adapter_name", "default") use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) generation_config = kwargs.pop("generation_config", None) - gguf_file = kwargs.pop("gguf_file", None) - # Cache path to the GGUF file - gguf_path = None - tp_plan = kwargs.pop("tp_plan", None) + key_mapping = kwargs.pop("key_mapping", None) + + if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None): + raise ValueError( + "`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies." + ) + if tp_plan is not None and tp_plan != "auto": # TODO: we can relax this check when we support taking tp_plan from a json file, for example. raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.") @@ -3684,7 +4106,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix CONFIG_NAME, cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, @@ -3706,7 +4127,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix pretrained_model_name_or_path, cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, _commit_hash=commit_hash, @@ -3791,7 +4211,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, @@ -3856,334 +4275,41 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if low_cpu_mem_usage is None: low_cpu_mem_usage = True logger.warning("`low_cpu_mem_usage` was None, now default to True since model is quantized.") - is_quantized = hf_quantizer is not None - - # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the - # index of the files. - is_sharded = False - sharded_metadata = None - # Load model - loading_info = None - - # Keep in fp32 modules - keep_in_fp32_modules = None - use_keep_in_fp32_modules = False if gguf_file is not None and hf_quantizer is not None: raise ValueError( "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub." ) - if pretrained_model_name_or_path is not None and gguf_file is None: - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - is_local = os.path.isdir(pretrained_model_name_or_path) - if is_local: - if from_tf and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") - ): - # Load from a TF 1.0 checkpoint in priority if from_tf - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") - elif from_tf and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) - ): - # Load from a TF 2.0 checkpoint in priority if from_tf - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) - elif from_flax and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) - ): - # Load from a Flax checkpoint in priority if from_flax - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) - elif use_safetensors is not False and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) - ): - # Load from a safetensors checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) - ) - elif use_safetensors is not False and os.path.isfile( - os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) - ) - ): - # Load from a sharded safetensors checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) - ) - is_sharded = True - elif not use_safetensors and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) - ): - # Load from a PyTorch checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) - ) - elif not use_safetensors and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) - ): - # Load from a sharded PyTorch checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) - ) - is_sharded = True - # At this stage we don't have a weight file so we will raise an error. - elif not use_safetensors and ( - os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")) - or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)) - ): - raise EnvironmentError( - f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" - f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use" - " `from_tf=True` to load this model from those weights." - ) - elif not use_safetensors and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) - ): - raise EnvironmentError( - f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" - f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`" - " to load this model from those weights." - ) - elif use_safetensors: - raise EnvironmentError( - f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" - f" {pretrained_model_name_or_path}." - ) - else: - raise EnvironmentError( - f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," - f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" - f" {pretrained_model_name_or_path}." - ) - elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): - archive_file = pretrained_model_name_or_path - is_local = True - elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")): - if not from_tf: - raise ValueError( - f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set " - "from_tf to True to load from this checkpoint." - ) - archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index") - is_local = True - elif is_remote_url(pretrained_model_name_or_path): - filename = pretrained_model_name_or_path - resolved_archive_file = download_url(pretrained_model_name_or_path) - else: - # set correct filename - if from_tf: - filename = TF2_WEIGHTS_NAME - elif from_flax: - filename = FLAX_WEIGHTS_NAME - elif use_safetensors is not False: - filename = _add_variant(SAFE_WEIGHTS_NAME, variant) - else: - filename = _add_variant(WEIGHTS_NAME, variant) + checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + variant=variant, + gguf_file=gguf_file, + from_tf=from_tf, + from_flax=from_flax, + use_safetensors=use_safetensors, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + commit_hash=commit_hash, + ) - try: - # Load from URL or cache if already cached - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "resume_download": resume_download, - "local_files_only": local_files_only, - "token": token, - "user_agent": user_agent, - "revision": revision, - "subfolder": subfolder, - "_raise_exceptions_for_gated_repo": False, - "_raise_exceptions_for_missing_entries": False, - "_commit_hash": commit_hash, - } - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - - # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None - # result when internet is up, the repo and revision exist, but the file does not. - if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - resolved_archive_file = cached_file( - pretrained_model_name_or_path, - _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), - **cached_file_kwargs, - ) - if resolved_archive_file is not None: - is_sharded = True - elif use_safetensors: - if revision == "main": - resolved_archive_file, revision, is_sharded = auto_conversion( - pretrained_model_name_or_path, **cached_file_kwargs - ) - cached_file_kwargs["revision"] = revision - if resolved_archive_file is None: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " - "and thus cannot be loaded with `safetensors`. Please make sure that the model has " - "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." - ) - else: - # This repo has no safetensors file of any kind, we switch to PyTorch. - filename = _add_variant(WEIGHTS_NAME, variant) - resolved_archive_file = cached_file( - pretrained_model_name_or_path, filename, **cached_file_kwargs - ) - if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - resolved_archive_file = cached_file( - pretrained_model_name_or_path, - _add_variant(WEIGHTS_INDEX_NAME, variant), - **cached_file_kwargs, - ) - if resolved_archive_file is not None: - is_sharded = True - if not local_files_only and not is_offline_mode(): - if resolved_archive_file is not None: - if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]: - # If the PyTorch file was found, check if there is a safetensors file on the repository - # If there is no safetensors file on the repositories, start an auto conversion - safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME - has_file_kwargs = { - "revision": revision, - "proxies": proxies, - "token": token, - "cache_dir": cache_dir, - "local_files_only": local_files_only, - } - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "resume_download": resume_download, - "local_files_only": local_files_only, - "user_agent": user_agent, - "subfolder": subfolder, - "_raise_exceptions_for_gated_repo": False, - "_raise_exceptions_for_missing_entries": False, - "_commit_hash": commit_hash, - **has_file_kwargs, - } - if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): - Thread( - target=auto_conversion, - args=(pretrained_model_name_or_path,), - kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs}, - name="Thread-auto_conversion", - ).start() - else: - # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. - # We try those to give a helpful error message. - has_file_kwargs = { - "revision": revision, - "proxies": proxies, - "token": token, - "cache_dir": cache_dir, - "local_files_only": local_files_only, - } - if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights." - " Use `from_tf=True` to load this model from those weights." - ) - elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use" - " `from_flax=True` to load this model from those weights." - ) - elif variant is not None and has_file( - pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs - ): - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" - f" {variant}. Use `variant=None` to load this model from those weights." - ) - else: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," - f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." - ) - - except EnvironmentError: - # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted - # to the original exception. - raise - except Exception as e: - # For any other exception, we throw a generic error. - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" - " from 'https://huggingface.co/models', make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}," - f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." - ) from e - - if is_local: - logger.info(f"loading weights file {archive_file}") - resolved_archive_file = archive_file - else: - logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") - elif gguf_file: - from .modeling_gguf_pytorch_utils import load_gguf_checkpoint - - # Case 1: the GGUF file is present locally - if os.path.isfile(gguf_file): - gguf_path = gguf_file - # Case 2: The GGUF path is a location on the Hub - # Load from URL or cache if already cached - else: - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "resume_download": resume_download, - "local_files_only": local_files_only, - "token": token, - "user_agent": user_agent, - "revision": revision, - "subfolder": subfolder, - "_raise_exceptions_for_gated_repo": False, - "_raise_exceptions_for_missing_entries": False, - "_commit_hash": commit_hash, - } - - gguf_path = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs) - - # we need a dummy model to help rename state_dict - with torch.device("meta"): - dummy_model = cls(config) - state_dict = load_gguf_checkpoint(gguf_path, return_tensors=True, model_to_load=dummy_model)["tensors"] - - resolved_archive_file = None - is_sharded = False - else: - resolved_archive_file = None - - # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. - if is_sharded: - # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path, - resolved_archive_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _commit_hash=commit_hash, - ) + is_sharded = sharded_metadata is not None + is_quantized = hf_quantizer is not None + is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None if ( is_safetensors_available() - and isinstance(resolved_archive_file, str) - and resolved_archive_file.endswith(".safetensors") + and is_from_file + and not is_sharded + and checkpoint_files[0].endswith(".safetensors") ): - with safe_open(resolved_archive_file, framework="pt") as f: + with safe_open(checkpoint_files[0], framework="pt") as f: metadata = f.metadata() if metadata is None: @@ -4207,96 +4333,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix from_pt = not (from_tf | from_flax) - # load pt weights early so that we know which dtype to init the model under - if from_pt: - if not is_sharded and state_dict is None: - # Time to load the checkpoint - state_dict = load_state_dict(resolved_archive_file, map_location="meta", weights_only=weights_only) + if gguf_file: + from .modeling_gguf_pytorch_utils import load_gguf_checkpoint - # set dtype to instantiate the model under: - # 1. If torch_dtype is not None, we use that dtype - # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first - # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype - # we also may have config.torch_dtype available, but we won't rely on it till v5 - dtype_orig = None + # we need a dummy model to get the state_dict - for this reason, we keep the state_dict as if it was + # passed directly as a kwarg from now on + with torch.device("meta"): + dummy_model = cls(config) + state_dict = load_gguf_checkpoint(checkpoint_files[0], return_tensors=True, model_to_load=dummy_model)[ + "tensors" + ] + # Force it if is not already the case + low_cpu_mem_usage = True - if torch_dtype is not None: - if isinstance(torch_dtype, str): - if torch_dtype == "auto": - if hasattr(config, "torch_dtype") and config.torch_dtype is not None: - torch_dtype = config.torch_dtype - logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object") - else: - if is_sharded and "dtype" in sharded_metadata: - torch_dtype = sharded_metadata["dtype"] - elif not is_sharded: - torch_dtype = get_state_dict_dtype(state_dict) - else: - one_state_dict = load_state_dict( - resolved_archive_file[0], map_location="meta", weights_only=weights_only - ) - torch_dtype = get_state_dict_dtype(one_state_dict) - del one_state_dict # free CPU memory - logger.info( - "Since the `torch_dtype` attribute can't be found in model's config object, " - "will use torch_dtype={torch_dtype} as derived from model's weights" - ) - elif hasattr(torch, torch_dtype): - torch_dtype = getattr(torch, torch_dtype) - for sub_config_key in config.sub_configs.keys(): - sub_config = getattr(config, sub_config_key) - sub_config.torch_dtype = torch_dtype - elif isinstance(torch_dtype, torch.dtype): - for sub_config_key in config.sub_configs.keys(): - sub_config = getattr(config, sub_config_key) - sub_config.torch_dtype = torch_dtype - elif isinstance(torch_dtype, dict): - for key, curr_dtype in torch_dtype.items(): - if hasattr(config, key): - value = getattr(config, key) - value.torch_dtype = curr_dtype - # main torch dtype for modules that aren't part of any sub-config - torch_dtype = torch_dtype.get("") - config.torch_dtype = torch_dtype - if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype): - torch_dtype = getattr(torch, torch_dtype) - elif torch_dtype is None: - torch_dtype = torch.float32 - else: - raise ValueError( - f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` " - f"for each sub-config in composite configs, but received {torch_dtype}" - ) - - dtype_orig = cls._set_default_torch_dtype(torch_dtype) - else: - # set fp32 as the default dtype for BC - default_dtype = str(torch.get_default_dtype()).split(".")[-1] - config.torch_dtype = default_dtype - for key in config.sub_configs.keys(): - value = getattr(config, key) - value.torch_dtype = default_dtype - - # Check if `_keep_in_fp32_modules` is not None - use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( - (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + # Find the correct dtype based on current state + config, torch_dtype, dtype_orig = _get_torch_dtype( + cls, torch_dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only ) - if is_sharded: - loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] - else: - loaded_state_dict_keys = list(state_dict.keys()) - if ( - gguf_path is None - and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())) - and pretrained_model_name_or_path is not None - ): - # In case some weights need to be kept in float32 and accelerate is not installed, - # we later on want to take the path where state_dict is not None, that is the one - # that do not require accelerate. - state_dict = None - config.name_or_path = pretrained_model_name_or_path # Instantiate model. @@ -4312,6 +4367,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) + # Make sure to tie the weights correctly + model.tie_weights() + + # Last check for tp if device_mesh is not None and not model.supports_tp_plan: if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None: raise NotImplementedError("This model does not have a tensor parallel plan.") @@ -4319,13 +4378,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # make sure we use the model's config since the __init__ call might have copied it config = model.config - # Check first if we are `from_pt` - if use_keep_in_fp32_modules: + # Find fp32 modules if needed + keep_in_fp32_modules = None + if model._keep_in_fp32_modules is not None: if is_accelerate_available() and not is_deepspeed_zero3_enabled(): low_cpu_mem_usage = True - keep_in_fp32_modules = model._keep_in_fp32_modules - else: - keep_in_fp32_modules = [] + keep_in_fp32_modules = model._keep_in_fp32_modules if len(model._keep_in_fp32_modules) > 0 else None if hf_quantizer is not None: hf_quantizer.preprocess_model( @@ -4338,104 +4396,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # remain a single source of truth config._pre_quantization_dtype = torch_dtype - if isinstance(device_map, str): - special_dtypes = {} - - if hf_quantizer is not None: - special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) - - special_dtypes.update( - { - name: torch.float32 - for name, _ in model.named_parameters() - if any(m in name for m in keep_in_fp32_modules) - } - ) - - target_dtype = torch_dtype - - if hf_quantizer is not None: - target_dtype = hf_quantizer.adjust_target_dtype(target_dtype) - - no_split_modules = model._get_no_split_modules(device_map) - if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: - raise ValueError( - "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " - "'sequential'." - ) - - device_map_kwargs = {"no_split_module_classes": no_split_modules} - if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: - device_map_kwargs["special_dtypes"] = special_dtypes - elif len(special_dtypes) > 0: - logger.warning( - "This model has some weights that should be kept in higher precision, you need to upgrade " - "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." - ) - if device_map != "sequential": - max_memory = get_balanced_memory( - model, - dtype=target_dtype, - low_zero=(device_map == "balanced_low_0"), - max_memory=max_memory, - **device_map_kwargs, - ) - else: - max_memory = get_max_memory(max_memory) - if hf_quantizer is not None: - max_memory = hf_quantizer.adjust_max_memory(max_memory) - device_map_kwargs["max_memory"] = max_memory - - # Make sure tied weights are tied before creating the device map. - model.tie_weights() - device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) - - if hf_quantizer is not None: - hf_quantizer.validate_environment(device_map=device_map) - - elif device_map is not None: - model.tie_weights() - tied_params = find_tied_parameters(model) - # check if we don't have tied param in different devices - check_tied_parameters_on_same_device(tied_params, device_map) - - if gguf_path and device_map is not None and "disk" in device_map.values(): - raise RuntimeError( - "One or more modules is configured to be mapped to disk. Disk offload is not supported for models " - "loaded from GGUF files." + # Prepare the full device map + if device_map is not None: + device_map = _get_device_map( + model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules ) + # Finalize model weight initialization if from_tf: - if resolved_archive_file.endswith(".index"): - # Load from a TensorFlow 1.X checkpoint - provided by original authors - model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' - else: - # Load from our TensorFlow 2.0 checkpoints - try: - from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model - - model, loading_info = load_tf2_checkpoint_in_pytorch_model( - model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True - ) - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed." - " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation" - " instructions." - ) - raise + model, loading_info = cls._load_from_tf(model, config, checkpoint_files) elif from_flax: - try: - from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model - - model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file) - except ImportError: - logger.error( - "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see" - " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for" - " installation instructions." - ) - raise + model = cls._load_from_flax(model, checkpoint_files) elif from_pt: # restore default dtype if dtype_orig is not None: @@ -4451,22 +4422,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) = cls._load_pretrained_model( model, state_dict, - loaded_state_dict_keys, # XXX: rename? - resolved_archive_file or gguf_file, + checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes=ignore_mismatched_sizes, sharded_metadata=sharded_metadata, - _fast_init=_fast_init, low_cpu_mem_usage=low_cpu_mem_usage, device_map=device_map, - offload_folder=offload_folder, + disk_offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, - gguf_path=gguf_path, - weights_only=weights_only, device_mesh=device_mesh, + key_mapping=key_mapping, + weights_only=weights_only, + _fast_init=_fast_init, ) # make sure token embedding weights are still tied if needed @@ -4485,7 +4455,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix pretrained_model_name_or_path, cache_dir=cache_dir, force_download=force_download, - resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, @@ -4530,13 +4499,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): dispatch_model(model, **device_map_kwargs) - # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is - # not part of the state_dict (persistent=False) - if device_mesh is not None: - for buffer in model.buffers(): - if buffer.device != tp_device: - buffer.data = buffer.to(tp_device) - if hf_quantizer is not None: hf_quantizer.postprocess_model(model, config=config) model.hf_quantizer = hf_quantizer @@ -4550,19 +4512,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) if output_loading_info: - if loading_info is None: + if from_pt: loading_info = { "missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "mismatched_keys": mismatched_keys, "error_msgs": error_msgs, } + elif from_flax: + loading_info = None return model, loading_info return model @staticmethod - def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: + def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]: """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert) # This rename is logged. @@ -4587,38 +4551,48 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return key, False - def rename_key(self, key): + def _get_key_renaming_mapping( + self, + checkpoint_keys: List[str], + key_mapping: Optional[Dict[str, str]] = None, + loading_base_model_from_task_state_dict: bool = False, + loading_task_model_from_base_state_dict: bool = False, + ): """ - When we load a LlamaModel from a checkpoint made using LlamaForCausalLM, the keys have an extra - prefix, which can be accessed in the `LlamaModel` via the `self.base_model_prefix` attribute. - - But, what if there is an extra layer on top of it? You load a MistralModel from a LlavaForConditionalGeneration? - In that what you actually want is to cut whatever is left of the key. - """ - new_key = key - if len(self.base_model_prefix) > 0: - if not hasattr(self, self.base_model_prefix) and key.startswith(self.base_model_prefix): - new_key = ".".join(key.split(".")[1:]) - elif ( - hasattr(self, self.base_model_prefix) - and not key.startswith(self.base_model_prefix) - and key not in self.expected_keys - ): - new_key = f"{self.base_model_prefix}.{key}" - - new_key, has_changed = self._fix_state_dict_key_on_load(new_key) - return new_key, has_changed - - def _fix_state_dict_keys_on_load(self, state_dict): - """Fixes state dict keys by replacing legacy parameter names with their modern equivalents. - Logs if any parameters have been renamed. + Compute a mapping between the serialized keys on disk `checkpoint_keys`, and the keys that the model + that we are loading expects. This is the single entry point for key renaming that will be used during + loading. + Log if any parameters have been renamed. """ + prefix = self.base_model_prefix + _prefix = f"{prefix}." renamed_keys = {} - state_dict_keys = list(state_dict.keys()) - for key in state_dict_keys: - new_key, has_changed = self.rename_key(key) - state_dict[new_key] = state_dict.pop(key) + key_renaming_mapping = {} + for key in checkpoint_keys: + # Class specific rename + new_key, has_changed = self._fix_state_dict_key_on_load(key) + + # Optionally map the key according to `key_mapping` + if key_mapping is not None: + for pattern, replacement in key_mapping.items(): + new_key, n_replace = re.subn(pattern, replacement, new_key) + # Early exit of the loop + if n_replace > 0: + has_changed = True + break + + # In this case, we need to add the prefix to the keys, to match them to the expected keys + if loading_task_model_from_base_state_dict: + new_key = ".".join([prefix, new_key]) + # In this case we need to remove the prefix from the key to match them to the expected keys, and use + # only the keys starting with the prefix + elif loading_base_model_from_task_state_dict: + if not new_key.startswith(_prefix): + continue + new_key = new_key[len(_prefix) :] + + key_renaming_mapping[key] = new_key # track gamma/beta rename for logging if has_changed: @@ -4635,7 +4609,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users." logger.info_once(warning_msg) - return state_dict + return key_renaming_mapping @staticmethod def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]: @@ -4655,433 +4629,310 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix @classmethod def _load_pretrained_model( cls, - model, - state_dict, - loaded_keys, - resolved_archive_file, - pretrained_model_name_or_path, - ignore_mismatched_sizes=False, - sharded_metadata=None, - _fast_init=True, - low_cpu_mem_usage=False, - device_map=None, - offload_folder=None, - offload_state_dict=None, - dtype=None, - hf_quantizer=None, - keep_in_fp32_modules=None, - gguf_path=None, - weights_only=True, - device_mesh=None, + model: "PreTrainedModel", + state_dict: Optional[Dict], + checkpoint_files: Optional[List[str]], + pretrained_model_name_or_path: Optional[str], + ignore_mismatched_sizes: bool = False, + sharded_metadata: Optional[Dict] = None, + low_cpu_mem_usage: bool = False, + device_map: Optional[Dict] = None, + disk_offload_folder: Optional[str] = None, + offload_state_dict: Optional[bool] = None, + dtype: Optional[torch.dtype] = None, + hf_quantizer: Optional[HfQuantizer] = None, + keep_in_fp32_modules: Optional[List[str]] = None, + device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, + key_mapping: Optional[Dict[str, str]] = None, + weights_only: bool = True, + _fast_init: bool = True, ): - is_safetensors = False + # Useful flags is_quantized = hf_quantizer is not None - state_dict_folder = None - state_dict_index = None - if device_map is not None and "disk" in device_map.values(): - archive_file = ( - resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file + # Get all the keys of the state dicts that we have to initialize the model + if sharded_metadata is not None: + original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"] + elif state_dict is not None: + original_checkpoint_keys = list(state_dict.keys()) + else: + original_checkpoint_keys = list( + load_state_dict(checkpoint_files[0], map_location="meta", weights_only=weights_only).keys() ) - is_safetensors = archive_file is not None and archive_file.endswith(".safetensors") - if offload_folder is None and not is_safetensors: + + # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture + prefix = model.base_model_prefix + _prefix = f"{prefix}." + has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False + expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False + loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module + loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module + + # Find the key names that the model expects from the serialized keys + key_renaming_mapping = model._get_key_renaming_mapping( + original_checkpoint_keys, + key_mapping, + loading_base_model_from_task_state_dict, + loading_task_model_from_base_state_dict, + ) + checkpoint_keys = list(key_renaming_mapping.values()) + + # Find missing and unexpected keys from the state dict + missing_keys, unexpected_keys = _find_missing_and_unexpected_keys( + cls, + model, + original_checkpoint_keys, + checkpoint_keys, + loading_base_model_from_task_state_dict, + hf_quantizer, + device_map, + ) + + # Move missing keys back to cpu from meta device (because they won't be moved when loading the weights as + # they are not in the loaded state dict) + if low_cpu_mem_usage: + model._move_missing_keys_from_meta_to_cpu(missing_keys, unexpected_keys, dtype, hf_quantizer) + # In this case we also need to move everything back + if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: + for key, param in model.state_dict().items(): + if param.device == torch.device("meta"): + set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype)) + + # correctly initialize the missing keys if it was skipped before + if _fast_init or low_cpu_mem_usage: + model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized) + + # Set some modules to fp32 if needed + if keep_in_fp32_modules is not None: + keep_in_fp32_modules = re.compile("|".join([re.escape(module) for module in keep_in_fp32_modules])) + for name, param in model.named_parameters(): + if keep_in_fp32_modules.search(name): + # param = param.to(torch.float32) does not work here as only in the local scope. + param.data = param.data.to(torch.float32) + + # Make sure we are able to load base models as well as derived models (specific task models, with heads) + model_to_load = model + # In this case, we load a ForTaskModel with keys from a BaseModel -> only load keys to the BaseModel + if loading_task_model_from_base_state_dict: + model_to_load = getattr(model, prefix) + # Here we need to remove the prefix we added to correctly find missing/unexpected keys, as we will load + # in the submodule + key_renaming_mapping = {k: v[len(_prefix) :] for k, v in key_renaming_mapping.items()} + checkpoint_keys = list(key_renaming_mapping.values()) + # We need to update the device map as well + if device_map is not None: + device_map = {k[len(_prefix) :] if k.startswith(_prefix) else k: v for k, v in device_map.items()} + # small sanity check: the base model should not contain task-specific head keys + task_specific_expected_keys = [s for s in model.state_dict().keys() if not s.startswith(_prefix)] + base_model_expected_keys = list(model_to_load.state_dict().keys()) + if any( + key in task_specific_expected_keys and key not in base_model_expected_keys for key in checkpoint_keys + ): + raise ValueError( + "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " + "properly saved?" + ) + + # Get reverse key mapping + reverse_key_renaming_mapping = {v: k for k, v in key_renaming_mapping.items()} + + is_offloaded_safetensors = False + # This offload index if for params explicitly on the "disk" in the device_map + disk_offload_index = None + disk_only_shard_files = [] + # Prepare parameters offloading if needed + if device_map is not None and "disk" in device_map.values(): + if offload_state_dict is None: + offload_state_dict = True + if disk_offload_folder is not None: + os.makedirs(disk_offload_folder, exist_ok=True) + is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") + if disk_offload_folder is None and not is_offloaded_safetensors: raise ValueError( "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" " offers the weights in this format." ) - if offload_folder is not None: - os.makedirs(offload_folder, exist_ok=True) - if offload_state_dict is None: - offload_state_dict = True - - is_sharded_safetensors = is_safetensors and sharded_metadata is not None - - # tie the model weights before retrieving the state_dict - model.tie_weights() - - # Retrieve missing & unexpected_keys - model_state_dict = model.state_dict() - expected_keys = list(model_state_dict.keys()) - prefix = model.base_model_prefix - - if hf_quantizer is not None: - expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys) - - original_loaded_keys = loaded_keys - loaded_keys = [model._fix_state_dict_key_on_load(key)[0] for key in loaded_keys] - - if len(prefix) > 0: - has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) - expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) - else: - has_prefix_module = False - expects_prefix_module = False - - # key re-naming operations are never done on the keys - # that are loaded, but always on the keys of the newly initialized model - remove_prefix_from_model = not has_prefix_module and expects_prefix_module - add_prefix_to_model = has_prefix_module and not expects_prefix_module - - if remove_prefix_from_model: - _prefix = f"{prefix}." - expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)] - expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys] - elif add_prefix_to_model: - expected_keys = [".".join([prefix, s]) for s in expected_keys] - - missing_keys = sorted(set(expected_keys) - set(loaded_keys)) - unexpected_keys = set(loaded_keys) - set(expected_keys) - - # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model - # buffers - model_buffers = {n for n, _ in model.named_buffers()} - if remove_prefix_from_model: - model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers} - elif add_prefix_to_model: - model_buffers = {".".join([prefix, key]) for key in model_buffers} - unexpected_keys = sorted(unexpected_keys - model_buffers) - - # Clean up buffer for `inv-freq` because RoPE embedding moved under base model (https://github.com/huggingface/transformers/pull/34858) - has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers) - if has_inv_freq_buffers: - unexpected_keys = {k for k in unexpected_keys if "rotary_emb.inv_freq" not in k} - - model.tie_weights() - if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): - ptrs = collections.defaultdict(list) - for name, tensor in model.state_dict().items(): - id_tensor = id_tensor_storage(tensor) - ptrs[id_tensor].append(name) - - # These are all the pointers of shared tensors. - tied_params = [names for _, names in ptrs.items() if len(names) > 1] - else: - # id function doesn't work for meta tensor so we need this function - tied_params = find_tied_parameters(model) - - for group in tied_params: - if remove_prefix_from_model: - group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group] - elif add_prefix_to_model: - group = [".".join([prefix, key]) for key in group] - missing_in_group = [k for k in missing_keys if k in group] - if len(missing_in_group) > 0 and len(missing_in_group) < len(group): - missing_keys = [k for k in missing_keys if k not in missing_in_group] - - # Some models may have keys that are not in the state by design, removing them before needlessly warning - # the user. - if cls._keys_to_ignore_on_load_missing is not None: - for pat in cls._keys_to_ignore_on_load_missing: - missing_keys = [k for k in missing_keys if re.search(pat, k) is None] - - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] - if hf_quantizer is not None: - missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) - unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix) - - # retrieve weights on meta device and put them back on CPU. - # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step - if low_cpu_mem_usage: - for key in missing_keys: - if key in model_state_dict: - key = key - elif f"{prefix}.{key}" in model_state_dict: - key = f"{prefix}.{key}" - elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in model_state_dict: - key = ".".join(key.split(".")[1:]) - param = model_state_dict[key] - - # upcast in fp32 if any - target_dtype = dtype - if ( - keep_in_fp32_modules is not None - and dtype == torch.float16 - and any( - module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules - ) - ): - target_dtype = torch.float32 - - if param.device == torch.device("meta"): - value = torch.empty(*param.size(), dtype=target_dtype) - if ( - not is_quantized - or (getattr(hf_quantizer, "requires_parameters_quantization", False)) - or not hf_quantizer.check_quantized_param( - model, param_value=value, param_name=key, state_dict={} - ) - ): - set_module_tensor_to_device(model, key, "cpu", value) - else: - hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys) - - # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights. - if _fast_init: - if not ignore_mismatched_sizes: - if remove_prefix_from_model: - _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys] - elif add_prefix_to_model: - _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] + if is_offloaded_safetensors: + param_device_map = expand_device_map(device_map, checkpoint_keys) + str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" + if sharded_metadata is None: + weight_map = {p: checkpoint_files[0] for p in checkpoint_keys} else: - _loaded_keys = loaded_keys - not_initialized_submodules = set_initialized_submodules(model, _loaded_keys) - # If we're about to tie the output embeds to the input embeds we don't need to init them - if ( - hasattr(model.config.get_text_config(decoder=True), "tie_word_embeddings") - and model.config.get_text_config(decoder=True).tie_word_embeddings - ): - output_embeddings = model.get_output_embeddings() - if output_embeddings is not None: - # Still need to initialize if there is a bias term since biases are not tied. - if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None: - output_embeddings._is_hf_initialized = True + folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1]) + # Fix the weight map keys according to the key mapping + weight_map = { + key_renaming_mapping[k]: v + for k, v in sharded_metadata["weight_map"].items() + if k in key_renaming_mapping + } + weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()} + # Find potential checkpoints containing only offloaded weights + disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map) + disk_offload_index = { + name: { + "safetensors_file": file, + "weight_name": reverse_key_renaming_mapping[name], + "dtype": str_dtype, + } + for name, file in weight_map.items() + if param_device_map[name] == "disk" + } else: - not_initialized_submodules = dict(model.named_modules()) - # This will only initialize submodules that are not marked as initialized by the line above. - if is_deepspeed_zero3_enabled() and not is_quantized: - import deepspeed + disk_offload_index = {} - not_initialized_parameters = list( - set( - itertools.chain.from_iterable( - submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values() - ) - ) - ) - with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0): - model.apply(model._initialize_weights) - else: - model.apply(model._initialize_weights) + # This offload index if for params that are supposed to be on the "cpu", either with or without a device_map + # It allows to load parameters one-by-one from the state dict, avoiding a memory peak of 2 x state_dict_size, + # i.e. 1x to load it, and 1x to copy it to model + cpu_offload_folder = None + cpu_offload_index = None + if offload_state_dict: + cpu_offload_folder = tempfile.mkdtemp() + cpu_offload_index = {} - # Set some modules to fp32 if any - if keep_in_fp32_modules == []: - keep_in_fp32_modules = None - if keep_in_fp32_modules is not None: - keep_in_fp32_modules = re.compile("|".join(keep_in_fp32_modules)) - for name, param in model.named_parameters(): - if keep_in_fp32_modules.search(name): - # param = param.to(torch.float32) does not work here as only in the local scope. - param.data = param.data.to(torch.float32) # TODO @Cyrilvallez: we seem to do this twice + # For nice tqdm bars + if checkpoint_files is not None and len(checkpoint_files) > 1: + checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards") + # To be able to iterate, even if we don't use it if the state_dict is already provided + elif state_dict is not None: + checkpoint_files = [""] - # Make sure we are able to load base models as well as derived models (with heads) - start_prefix = "" - model_to_load = model - if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module: - start_prefix = cls.base_model_prefix + "." - if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: - model_to_load = getattr(model, cls.base_model_prefix) - base_model_expected_keys = list(model_to_load.state_dict().keys()) - if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys): - raise ValueError( - "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " - "properly saved?" - ) - if device_map is not None: - device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()} + # Compute expected model keys + expected_keys = list(model_to_load.state_dict().keys()) + if hf_quantizer is not None: + expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys) - if resolved_archive_file is not None: - folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) - else: - folder = None - - model_to_load.expected_keys = expected_keys - if device_map is not None: - expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) - if hf_quantizer is None: - caching_allocator_warmup(model_to_load, expanded_device_map, dtype) - - if device_map is not None and is_safetensors: - param_device_map = expanded_device_map - str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" - if sharded_metadata is None: - archive_file = ( - resolved_archive_file[0] - if isinstance(resolved_archive_file, (list, tuple)) - else resolved_archive_file - ) - weight_map = {p: archive_file for p in original_loaded_keys} - else: - weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} - offload_index = { - p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} - for p, f in weight_map.items() - if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk" - } - else: - offload_index = None + # Warmup cuda to load the weights much faster on devices + if device_map is not None and hf_quantizer is None: + expanded_device_map = expand_device_map(device_map, expected_keys) + caching_allocator_warmup(model_to_load, expanded_device_map, dtype) error_msgs = [] - if state_dict is not None: - # Whole checkpoint - mismatched_keys = _find_mismatched_keys( + mismatched_keys = [] + # Iterate on all the shards to load the weights + for shard_file in checkpoint_files: + # Skip the load for shards that only contain disk-offloaded weights + if shard_file in disk_only_shard_files: + continue + + map_location = "cpu" + if low_cpu_mem_usage: + if shard_file.endswith(".safetensors") and not is_quantized: + map_location = "meta" + elif ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + + # If shard_file is""", we use the existing state_dict instead of loading it + if shard_file != "": + state_dict = load_state_dict( + shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only + ) + + # Fix the key names + state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys += _find_mismatched_keys( + model_to_load, state_dict, - model_state_dict, - loaded_keys, - original_loaded_keys, - add_prefix_to_model, - remove_prefix_from_model, ignore_mismatched_sizes, - prefix, + prefix if loading_base_model_from_task_state_dict else "", ) - # For GGUF models `state_dict` is never set to None as the state dict is always small - if gguf_path or low_cpu_mem_usage and is_safetensors: - error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( - model_to_load, - state_dict, - start_prefix, - expected_keys, - device_map=device_map, - offload_folder=offload_folder, - offload_index=offload_index, - state_dict_folder=state_dict_folder, - state_dict_index=state_dict_index, - dtype=dtype, - hf_quantizer=hf_quantizer, - is_safetensors=is_safetensors, - keep_in_fp32_modules=keep_in_fp32_modules, - unexpected_keys=unexpected_keys, - device_mesh=device_mesh, - shard_file=resolved_archive_file, - weights_only=weights_only, - ) - else: - # We need to read the state dict as it is meta otherwise - if resolved_archive_file is not None: - state_dict = load_state_dict(resolved_archive_file, map_location="cpu") - assign_to_params_buffers = check_support_param_buffer_assignment( - model_to_load, state_dict, start_prefix - ) - # at this point the state dict should be on cpu, we don't need to actually read it - mismatched_names = [name for name, _, _ in mismatched_keys] - fixed_state_dict = {k: v for k, v in state_dict.items() if k not in mismatched_names} - fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(fixed_state_dict) - - if is_deepspeed_zero3_enabled(): - error_msgs += _load_state_dict_into_zero3_model( - model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers + if low_cpu_mem_usage and shard_file is not None: + # Skip it with fsdp on ranks other than 0 + if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): + disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + shard_file, + expected_keys, + reverse_key_renaming_mapping, + device_map=device_map, + disk_offload_folder=disk_offload_folder, + disk_offload_index=disk_offload_index, + cpu_offload_folder=cpu_offload_folder, + cpu_offload_index=cpu_offload_index, + hf_quantizer=hf_quantizer, + is_safetensors=is_offloaded_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + device_mesh=device_mesh, + weights_only=weights_only, ) - else: - model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers) - else: - # This should always be a list but, just to be sure. - if not isinstance(resolved_archive_file, list): - resolved_archive_file = [resolved_archive_file] - - error_msgs = [] - mismatched_keys = [] - if not is_safetensors: - offload_index = {} if device_map is not None and "disk" in device_map.values() else None - if offload_state_dict: - state_dict_folder = tempfile.mkdtemp() - state_dict_index = {} else: - state_dict_folder = None - state_dict_index = None - - if is_sharded_safetensors: - disk_only_shard_files = get_disk_only_shard_files( - device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix - ) - disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files] - else: - disk_only_shard_files = [] - - if len(resolved_archive_file) > 1: - resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") - assign_to_params_buffers = None - for shard_file in resolved_archive_file: - # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. - if shard_file in disk_only_shard_files: - continue - state_dict = load_state_dict( - shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only - ) - - # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not - # matching the weights in the model. - mismatched_keys += _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - original_loaded_keys, - add_prefix_to_model, - remove_prefix_from_model, - ignore_mismatched_sizes, - prefix, - ) - if low_cpu_mem_usage: - if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: - for key, param in model_to_load.state_dict().items(): - if param.device == torch.device("meta"): - set_module_tensor_to_device( - model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) - ) - else: - new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( - model_to_load, - state_dict, - prefix, - expected_keys, - device_map=device_map, - offload_folder=offload_folder, - offload_index=offload_index, - state_dict_folder=state_dict_folder, - state_dict_index=state_dict_index, - dtype=dtype, - hf_quantizer=hf_quantizer, - is_safetensors=is_safetensors, - keep_in_fp32_modules=keep_in_fp32_modules, - unexpected_keys=unexpected_keys, - device_mesh=device_mesh, - shard_file=shard_file, - weights_only=weights_only, - ) - error_msgs += new_error_msgs + assign_params = check_support_param_buffer_assignment(model_to_load, state_dict) + if is_deepspeed_zero3_enabled(): + error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params) else: - state_dict = load_state_dict(shard_file, map_location="cpu", weights_only=weights_only) - # Sharded checkpoint or whole but low_cpu_mem_usage==True - if assign_to_params_buffers is None: - assign_to_params_buffers = check_support_param_buffer_assignment( - model_to_load, state_dict, start_prefix - ) - fixed_state_dict = model_to_load._fix_state_dict_keys_on_load(state_dict) - if is_deepspeed_zero3_enabled(): - error_msgs += _load_state_dict_into_zero3_model( - model_to_load, fixed_state_dict, start_prefix, assign_to_params_buffers - ) - else: - model_to_load.load_state_dict(fixed_state_dict, strict=False, assign=assign_to_params_buffers) - # force memory release - del state_dict - gc.collect() + model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) - if offload_index is not None and len(offload_index) > 0: - if model != model_to_load: - # We need to add the prefix of the base model - prefix = cls.base_model_prefix - if not is_safetensors: - for weight_name in offload_index: - shutil.move( - os.path.join(offload_folder, f"{weight_name}.dat"), - os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"), - ) - offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()} - if not is_safetensors: - save_offload_index(offload_index, offload_folder) - offload_index = None + # force memory release + del state_dict + gc.collect() - if offload_state_dict: - # Load back temporarily offloaded state dict - load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder) - shutil.rmtree(state_dict_folder) + # Adjust offloaded weights name and save if needed + if disk_offload_index is not None and len(disk_offload_index) > 0: + if loading_task_model_from_base_state_dict: + # We need to add the prefix of the base model + prefix = cls.base_model_prefix + if not is_offloaded_safetensors: + for weight_name in disk_offload_index: + shutil.move( + os.path.join(disk_offload_folder, f"{weight_name}.dat"), + os.path.join(disk_offload_folder, f"{prefix}.{weight_name}.dat"), + ) + disk_offload_index = {f"{prefix}.{key}": value for key, value in disk_offload_index.items()} + if not is_offloaded_safetensors: + save_offload_index(disk_offload_index, disk_offload_folder) + disk_offload_index = None + + # one-at-a-time param loading for the cpu offloaded params + if offload_state_dict: + # Load back temporarily offloaded state dict + load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder) + shutil.rmtree(cpu_offload_folder) if hf_quantizer is not None: missing_keys = hf_quantizer.update_missing_keys_after_loading(model_to_load, missing_keys, prefix) + # Post-processing for tensor parallelism + if device_mesh is not None: + # When using TP, the device map is a single device for all parameters + tp_device = list(device_map.values())[0] + # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is + # not part of the state_dict (persistent=False) + for buffer in model.buffers(): + if buffer.device != tp_device: + buffer.data = buffer.to(tp_device) + + # In this case, the top-most task module weights were not moved to device and parallelized as they + # were not part of the loaded weights: do it now + if loading_task_model_from_base_state_dict: + parameters_to_initialize = { + name: param for name, param in model.named_parameters() if not name.startswith(prefix) + } + for name, param in parameters_to_initialize.items(): + # First move data to correct + to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_modules) + shard_and_distribute_module( + model, + param.to(tp_device), + param, + name, + casting_dtype, + to_contiguous, + tp_device.index, + device_mesh, + ) + + # All potential warnings/infos if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) if "size mismatch" in error_msg: @@ -5089,7 +4940,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." ) raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") - if len(unexpected_keys) > 0: archs = [] if model.config.architectures is None else model.config.architectures warner = logger.warning if model.__class__.__name__ in archs else logger.info @@ -5131,7 +4981,45 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix " to use it for predictions and inference." ) - return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs + return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs + + @classmethod + def _load_from_tf(cls, model, config, checkpoint_files): + if checkpoint_files[0].endswith(".index"): + # Load from a TensorFlow 1.X checkpoint - provided by original authors + model = cls.load_tf_weights(model, config, checkpoint_files[0][:-6]) # Remove the '.index' + loading_info = None + else: + # Load from our TensorFlow 2.0 checkpoints + try: + from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model + + model, loading_info = load_tf2_checkpoint_in_pytorch_model( + model, checkpoint_files[0], allow_missing_keys=True, output_loading_info=True + ) + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed." + " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation" + " instructions." + ) + raise + return model, loading_info + + @classmethod + def _load_from_flax(cls, model, checkpoint_files): + try: + from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, checkpoint_files[0]) + except ImportError: + logger.error( + "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for" + " installation instructions." + ) + raise + return model def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): module_keys = {".".join(key.split(".")[:-1]) for key in names} @@ -5156,47 +5044,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return retrieved_modules - @staticmethod - def _load_pretrained_model_low_mem( - model, - loaded_state_dict_keys, - resolved_archive_file, - start_prefix="", - hf_quantizer=None, - pretrained_model_name_or_path=None, - weights_only=True, - ): - """ - This is an experimental function that loads the model using ~1.x model size CPU memory - - Before you call it do: - - 1. save which state_dict keys are available - 2. drop state_dict before model is created, since the latter takes 1x model size memory - - Here then we continue: - - 3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict - 4. load state_dict 2nd time - 5. replace the params/buffers from the state_dict - - Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. To - handle bitsandbytes, needs non-empty hf_quantizer argument. - """ - - _move_model_to_meta(model, loaded_state_dict_keys, start_prefix) - state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only) - expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys - fixed_state_dict = model._fix_state_dict_keys_on_load(state_dict) - error_msgs = _load_state_dict_into_meta_model( - model, - fixed_state_dict, - start_prefix, - expected_keys=expected_keys, - hf_quantizer=hf_quantizer, - ) - return error_msgs - @classmethod def register_for_auto_class(cls, auto_class="AutoModel"): """ @@ -5320,52 +5167,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return True return False - def tensor_parallel(self, device_mesh): - """ - Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model - was already loaded in memory, note however that this means that each process will first initialize the whole model, - then parallelize it across devices. Thus there is a huge waste of GPU memory, and this can lead to OOM at loading time. - - Calling `from_pretrained(..., tp_plan="auto")` is preferred, and will parallelize module-by-module during initialization, - so that the expected per-device memory spike at loading time is not larger than the final model size on each device. - Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model - was already loaded in memory, note however that this means that each process will first initialize the whole model, - then parallelize it across devices. Thus there is a huge waste of GPU memory, and this can lead to OOM at loading time. - - Args: - device_mesh (`torch.distributed.DeviceMesh`): - The device mesh to use for tensor parallelism. - """ - if not is_torch_greater_or_equal("2.5"): - raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.") - - # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module. - # No op if `_tp_plan` attribute does not exist under the module. - # This is a helper function to be used with `model.apply` to recursively - # parallelize a model. - def tplize(mod: torch.nn.Module) -> None: - tp_plan = getattr(mod, "_tp_plan", None) - if tp_plan is None: - return - logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}") - # In model configs, we use a neutral type (string) to specify - # parallel styles, here we translate them into torch TP types. - # Using tree_map because `tp_plan` is a dict. - tp_plan = torch.utils._pytree.tree_map( - translate_to_torch_parallel_style, - tp_plan, - ) - # Apply TP to current module. - torch.distributed.tensor.parallel.parallelize_module( - mod, - device_mesh=device_mesh, - parallelize_plan=tp_plan, - ) - - # `apply` is a native method of `nn.Module` that recursively applies a - # function to every submodule. - self.apply(tplize) - @property def supports_pp_plan(self): if self._pp_plan is not None: @@ -5413,6 +5214,91 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix def is_backend_compatible(cls): return cls._supports_attention_backend + def _move_missing_keys_from_meta_to_cpu( + self, + missing_keys: List[str], + unexpected_keys: List[str], + dtype: Optional[torch.dtype], + hf_quantizer: Optional[HfQuantizer], + ) -> "PreTrainedModel": + """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts) back + from meta device to cpu. + """ + is_quantized = hf_quantizer is not None + + model_state_dict = self.state_dict() + for key in missing_keys: + param = model_state_dict[key] + if param.device == torch.device("meta"): + # upcast in fp32 if any + target_dtype = dtype + value = torch.empty(*param.size(), dtype=target_dtype) + if ( + not is_quantized + or (getattr(hf_quantizer, "requires_parameters_quantization", False)) + or not hf_quantizer.check_quantized_param(self, param_value=value, param_name=key, state_dict={}) + ): + set_module_tensor_to_device(self, key, "cpu", value) + else: + hf_quantizer.create_quantized_param(self, value, key, "cpu", model_state_dict, unexpected_keys) + + def _initialize_missing_keys( + self, + loaded_keys: List[str], + ignore_mismatched_sizes: bool, + is_quantized: bool, + ) -> "PreTrainedModel": + """Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to + `_initialize_weights`. Indeed, since the corresponding weights are missing from the state dict, they will not be replaced and need to + be initialized correctly (i.e. weight initialization distribution). + Also take care of setting the `_is_hf_initialized` flag for keys that are not missing. + """ + if not ignore_mismatched_sizes: + not_initialized_submodules = set_initialized_submodules(self, loaded_keys) + # If we're about to tie the output embeds to the input embeds we don't need to init them + if ( + hasattr(self.config.get_text_config(decoder=True), "tie_word_embeddings") + and self.config.get_text_config(decoder=True).tie_word_embeddings + ): + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + # Still need to initialize if there is a bias term since biases are not tied. + if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None: + output_embeddings._is_hf_initialized = True + else: + not_initialized_submodules = dict(self.named_modules()) + # This will only initialize submodules that are not marked as initialized by the line above. + if is_deepspeed_zero3_enabled() and not is_quantized: + import deepspeed + + not_initialized_parameters = list( + set( + itertools.chain.from_iterable( + submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values() + ) + ) + ) + with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0): + self.apply(self._initialize_weights) + else: + self.apply(self._initialize_weights) + + def get_parameter_or_buffer(self, target: str): + """ + Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines + `get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a + leaf of the model. + """ + try: + return self.get_parameter(target) + except AttributeError: + pass + try: + return self.get_buffer(target) + except AttributeError: + pass + raise AttributeError(f"`{target}` is neither a parameter nor a buffer.") + PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) if PreTrainedModel.push_to_hub.__doc__ is not None: @@ -5870,12 +5756,11 @@ def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module: return model -def expand_device_map(device_map, param_names, start_prefix): +def expand_device_map(device_map, param_names): """ Expand a device map to return the correspondence parameter name to device. """ new_device_map = {} - param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)] for module, device in device_map.items(): new_device_map.update( {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""} @@ -5896,26 +5781,24 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, if not len(accelerator_device_map): return + tp_plan_regex = ( + re.compile("|".join([re.escape(plan) for plan in model._tp_plan])) + if _torch_distributed_available and torch.distributed.is_initialized() + else None + ) + parameter_count = defaultdict(lambda: 0) allocation_factor = 1 if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2: allocation_factor = 2 for param_name, device in accelerator_device_map.items(): - try: - param = getattr(model, param_name) - except AttributeError: - if "." in param_name: - param_name, param_type = param_name.rsplit(".", 1) - param = getattr(model.get_submodule(param_name), param_type) - else: - param = model.get_buffer(param_name) - + param = model.get_parameter_or_buffer(param_name) param_size = int(math.prod(param.shape) * allocation_factor) - if _torch_distributed_available and torch.distributed.is_initialized(): - generic_name = re.sub(r"\d+", "*", param_name) - param_size //= torch.distributed.get_world_size() if not model._tp_plan.get(generic_name, False) else 1 + if tp_plan_regex is not None: + generic_name = re.sub(r"\.\d+\.", ".*.", param_name) + param_size //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1 parameter_count[device] += param_size @@ -5931,14 +5814,10 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False) -def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): +def get_disk_only_shard_files(device_map, weight_map): """ Returns the list of shard files containing only weights offloaded to disk. """ - - weight_map = { - p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix) - } files_content = collections.defaultdict(list) for weight_name, filename in weight_map.items(): while len(weight_name) > 0 and weight_name not in device_map: diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 1b237caaba..134571014f 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -25,7 +25,7 @@ from typing import Dict, Optional, Union from ...configuration_utils import PretrainedConfig from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ...feature_extraction_utils import FeatureExtractionMixin -from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging +from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging from .auto_factory import _LazyAutoMapping from .configuration_auto import ( CONFIG_MAPPING_NAMES, @@ -220,7 +220,7 @@ def get_feature_extractor_config( raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") token = use_auth_token - resolved_config_file = get_file_from_repo( + resolved_config_file = cached_file( pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, cache_dir=cache_dir, @@ -230,6 +230,9 @@ def get_feature_extractor_config( token=token, revision=revision, local_files_only=local_files_only, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, ) if resolved_config_file is None: logger.info( diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index fedf1070e0..40c45fa94b 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -29,7 +29,7 @@ from ...image_processing_utils_fast import BaseImageProcessorFast from ...utils import ( CONFIG_NAME, IMAGE_PROCESSOR_NAME, - get_file_from_repo, + cached_file, is_timm_config_dict, is_timm_local_checkpoint, is_torchvision_available, @@ -288,7 +288,7 @@ def get_image_processor_config( raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") token = use_auth_token - resolved_config_file = get_file_from_repo( + resolved_config_file = cached_file( pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME, cache_dir=cache_dir, @@ -298,6 +298,9 @@ def get_image_processor_config( token=token, revision=revision, local_files_only=local_files_only, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, ) if resolved_config_file is None: logger.info( diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index d29d3f8d1a..53ddcd3d1e 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -28,7 +28,7 @@ from ...feature_extraction_utils import FeatureExtractionMixin from ...image_processing_utils import ImageProcessingMixin from ...processing_utils import ProcessorMixin from ...tokenization_utils import TOKENIZER_CONFIG_FILE -from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, get_file_from_repo, logging +from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging from .auto_factory import _LazyAutoMapping from .configuration_auto import ( CONFIG_MAPPING_NAMES, @@ -254,15 +254,21 @@ class AutoProcessor: processor_auto_map = None # First, let's see if we have a processor or preprocessor config. - # Filter the kwargs for `get_file_from_repo`. - get_file_from_repo_kwargs = { - key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs + # Filter the kwargs for `cached_file`. + cached_file_kwargs = { + key: kwargs[key] for key in inspect.signature(cached_file).parameters.keys() if key in kwargs } + # We don't want to raise + cached_file_kwargs.update( + { + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_raise_exceptions_for_connection_errors": False, + } + ) # Let's start by checking whether the processor class is saved in a processor config - processor_config_file = get_file_from_repo( - pretrained_model_name_or_path, PROCESSOR_NAME, **get_file_from_repo_kwargs - ) + processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs) if processor_config_file is not None: config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs) processor_class = config_dict.get("processor_class", None) @@ -271,8 +277,8 @@ class AutoProcessor: if processor_class is None: # If not found, let's check whether the processor class is saved in an image processor config - preprocessor_config_file = get_file_from_repo( - pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **get_file_from_repo_kwargs + preprocessor_config_file = cached_file( + pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs ) if preprocessor_config_file is not None: config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) @@ -291,8 +297,8 @@ class AutoProcessor: if processor_class is None: # Next, let's check whether the processor class is saved in a tokenizer - tokenizer_config_file = get_file_from_repo( - pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs + tokenizer_config_file = cached_file( + pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs ) if tokenizer_config_file is not None: with open(tokenizer_config_file, encoding="utf-8") as reader: diff --git a/src/transformers/models/bark/processing_bark.py b/src/transformers/models/bark/processing_bark.py index 0bed6ca79f..5fa5cd1916 100644 --- a/src/transformers/models/bark/processing_bark.py +++ b/src/transformers/models/bark/processing_bark.py @@ -25,7 +25,7 @@ import numpy as np from ...feature_extraction_utils import BatchFeature from ...processing_utils import ProcessorMixin from ...utils import logging -from ...utils.hub import get_file_from_repo +from ...utils.hub import cached_file from ..auto import AutoTokenizer @@ -86,7 +86,7 @@ class BarkProcessor(ProcessorMixin): """ if speaker_embeddings_dict_path is not None: - speaker_embeddings_path = get_file_from_repo( + speaker_embeddings_path = cached_file( pretrained_processor_name_or_path, speaker_embeddings_dict_path, subfolder=kwargs.pop("subfolder", None), @@ -97,6 +97,9 @@ class BarkProcessor(ProcessorMixin): local_files_only=kwargs.pop("local_files_only", False), token=kwargs.pop("use_auth_token", None), revision=kwargs.pop("revision", None), + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, ) if speaker_embeddings_path is None: logger.warning( @@ -182,7 +185,7 @@ class BarkProcessor(ProcessorMixin): f"Voice preset unrecognized, missing {key} as a key in self.speaker_embeddings[{voice_preset}]." ) - path = get_file_from_repo( + path = cached_file( self.speaker_embeddings.get("repo_or_path", "/"), voice_preset_paths[key], subfolder=kwargs.pop("subfolder", None), @@ -193,6 +196,9 @@ class BarkProcessor(ProcessorMixin): local_files_only=kwargs.pop("local_files_only", False), token=kwargs.pop("use_auth_token", None), revision=kwargs.pop("revision", None), + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, ) if path is None: raise ValueError( diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index cd68b391ba..45bb6aa49f 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -544,7 +544,7 @@ class CvtPreTrainedModel(PreTrainedModel): elif isinstance(module, CvtStage): if self.config.cls_token[module.stage]: module.cls_token.data = nn.init.trunc_normal_( - torch.zeros(1, 1, self.config.embed_dim[-1]), mean=0.0, std=self.config.initializer_range + module.cls_token.data, mean=0.0, std=self.config.initializer_range ) diff --git a/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py b/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py index a06b2e830d..280b52a1b9 100644 --- a/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py +++ b/src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py @@ -35,7 +35,7 @@ from torch import Tensor from vissl.models.model_helpers import get_trunk_forward_outputs from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel -from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict from transformers.utils import logging @@ -244,14 +244,18 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_ our_model_func = RegNetModel if "in1k" in model_name: our_model_func = RegNetForImageClassification - our_model = our_model_func(our_config) - # place our model to the meta device (so remove all the weights) - our_model.to(torch.device("meta")) + with torch.device("meta"): + our_model = our_model_func(our_config) logger.info("Loading state_dict in our model.") # load state dict state_dict_keys = our_model.state_dict().keys() - PreTrainedModel._load_pretrained_model_low_mem( - our_model, state_dict_keys, [save_directory / f"{model_name}.pth"] + state_dict = load_state_dict(save_directory / f"{model_name}.pth", weights_only=True) + fixed_state_dict = state_dict = {our_model._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()} + _load_state_dict_into_meta_model( + our_model, + fixed_state_dict, + start_prefix="", + expected_keys=state_dict_keys, ) logger.info("Finally, pushing!") # push it to hub diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index ff5cb5a838..48c1d85825 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -113,7 +113,7 @@ class TimmWrapperPreTrainedModel(PreTrainedModel): Override original method to fix state_dict keys on load for cases when weights are loaded without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint). """ - state_dict = self._fix_state_dict_keys_on_load(state_dict) + state_dict = {self._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()} return super().load_state_dict(state_dict, *args, **kwargs) def _init_weights(self, module): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 9561666db7..bdcb273c7e 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -91,7 +91,6 @@ from .hub import ( define_sagemaker_information, download_url, extract_commit_hash, - get_file_from_repo, has_file, http_user_agent, is_offline_mode, diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 21d6d04489..01d19c2140 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -40,6 +40,7 @@ from huggingface_hub import ( create_repo, hf_hub_download, hf_hub_url, + snapshot_download, try_to_load_from_cache, ) from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get @@ -47,7 +48,6 @@ from huggingface_hub.utils import ( EntryNotFoundError, GatedRepoError, HfHubHTTPError, - HFValidationError, LocalEntryNotFoundError, OfflineModeIsEnabled, RepositoryNotFoundError, @@ -69,7 +69,6 @@ from .import_utils import ( is_torch_available, is_training_run_on_sagemaker, ) -from .logging import tqdm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -209,21 +208,7 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] def cached_file( path_or_repo_id: Union[str, os.PathLike], filename: str, - cache_dir: Optional[Union[str, os.PathLike]] = None, - force_download: bool = False, - resume_download: Optional[bool] = None, - proxies: Optional[Dict[str, str]] = None, - token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - local_files_only: bool = False, - subfolder: str = "", - repo_type: Optional[str] = None, - user_agent: Optional[Union[str, Dict[str, str]]] = None, - _raise_exceptions_for_gated_repo: bool = True, - _raise_exceptions_for_missing_entries: bool = True, - _raise_exceptions_for_connection_errors: bool = True, - _commit_hash: Optional[str] = None, - **deprecated_kwargs, + **kwargs, ) -> Optional[str]: """ Tries to locate a file in a local folder and repo, downloads and cache it if necessary. @@ -231,7 +216,6 @@ def cached_file( Args: path_or_repo_id (`str` or `os.PathLike`): This can be either: - - a string, the *model id* of a model repo on huggingface.co. - a path to a *directory* potentially containing the file. filename (`str`): @@ -274,6 +258,94 @@ def cached_file( Examples: + ```python + # Download a model weight from the Hub and cache it. + model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin") + ``` + """ + file = cached_files(path_or_repo_id=path_or_repo_id, filenames=[filename], **kwargs) + file = file[0] if file is not None else file + return file + + +def cached_files( + path_or_repo_id: Union[str, os.PathLike], + filenames: List[str], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[Dict[str, str]] = None, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + local_files_only: bool = False, + subfolder: str = "", + repo_type: Optional[str] = None, + user_agent: Optional[Union[str, Dict[str, str]]] = None, + _raise_exceptions_for_gated_repo: bool = True, + _raise_exceptions_for_missing_entries: bool = True, + _raise_exceptions_for_connection_errors: bool = True, + _commit_hash: Optional[str] = None, + **deprecated_kwargs, +) -> Optional[str]: + """ + Tries to locate several files in a local folder and repo, downloads and cache them if necessary. + + Args: + path_or_repo_id (`str` or `os.PathLike`): + This can be either: + - a string, the *model id* of a model repo on huggingface.co. + - a path to a *directory* potentially containing the file. + filenames (`List[str]`): + The name of all the files to locate in `path_or_repo`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the standard + cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if they + exist. + resume_download: + Deprecated and ignored. All downloads are now resumed by default when possible. + Will be removed in v5 of Transformers. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, will only try to load the tokenizer configuration from local files. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + repo_type (`str`, *optional*): + Specify the repo type (useful when downloading from a space for instance). + + Private args: + _raise_exceptions_for_gated_repo (`bool`): + if False, do not raise an exception for gated repo error but return None. + _raise_exceptions_for_missing_entries (`bool`): + if False, do not raise an exception for missing entries but return None. + _raise_exceptions_for_connection_errors (`bool`): + if False, do not raise an exception for connection errors but return None. + _commit_hash (`str`, *optional*): + passed when we are chaining several calls to various files (e.g. when loading a tokenizer or + a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache. + + + + Passing `token=True` is required when you want to use a private model. + + + + Returns: + `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo). + + Examples: + ```python # Download a model weight from the Hub and cache it. model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin") @@ -289,144 +361,176 @@ def cached_file( raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") token = use_auth_token - # Private arguments - # _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return - # None. - # _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return - # None. - # _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return - # None. - # _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or - # a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache. if is_offline_mode() and not local_files_only: logger.info("Offline mode: forcing local_files_only=True") local_files_only = True if subfolder is None: subfolder = "" + # Add folder to filenames + full_filenames = [os.path.join(subfolder, file) for file in filenames] + path_or_repo_id = str(path_or_repo_id) - full_filename = os.path.join(subfolder, filename) - if os.path.isdir(path_or_repo_id): - resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename) - if not os.path.isfile(resolved_file): - if _raise_exceptions_for_missing_entries and filename not in ["config.json", f"{subfolder}/config.json"]: - raise EnvironmentError( - f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout " - f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files." - ) - else: - return None - return resolved_file + existing_files = [] + for filename in full_filenames: + if os.path.isdir(path_or_repo_id): + resolved_file = os.path.join(path_or_repo_id, filename) + if not os.path.isfile(resolved_file): + if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"): + revision_ = "main" if revision is None else revision + raise EnvironmentError( + f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout " + f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files." + ) + else: + return None + existing_files.append(resolved_file) + + # All files exist + if len(existing_files) == len(full_filenames): + return existing_files if cache_dir is None: cache_dir = TRANSFORMERS_CACHE if isinstance(cache_dir, Path): cache_dir = str(cache_dir) + existing_files = [] + file_counter = 0 if _commit_hash is not None and not force_download: - # If the file is cached under that commit hash, we return it directly. - resolved_file = try_to_load_from_cache( - path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type - ) - if resolved_file is not None: - if resolved_file is not _CACHED_NO_EXIST: - return resolved_file - elif not _raise_exceptions_for_missing_entries: - return None - else: - raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.") + for filename in full_filenames: + # If the file is cached under that commit hash, we return it directly. + resolved_file = try_to_load_from_cache( + path_or_repo_id, filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type + ) + if resolved_file is not None: + if resolved_file is not _CACHED_NO_EXIST: + file_counter += 1 + existing_files.append(resolved_file) + elif not _raise_exceptions_for_missing_entries: + file_counter += 1 + else: + raise EnvironmentError(f"Could not locate {filename} inside {path_or_repo_id}.") + + # Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries + if file_counter == len(full_filenames): + return existing_files if len(existing_files) > 0 else None user_agent = http_user_agent(user_agent) + # download the files if needed try: - # Load from URL or cache if already cached - resolved_file = hf_hub_download( - path_or_repo_id, - filename, - subfolder=None if len(subfolder) == 0 else subfolder, - repo_type=repo_type, - revision=revision, - cache_dir=cache_dir, - user_agent=user_agent, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - token=token, - local_files_only=local_files_only, + if len(full_filenames) == 1: + # This is slightly better for only 1 file + hf_hub_download( + path_or_repo_id, + filenames[0], + subfolder=None if len(subfolder) == 0 else subfolder, + repo_type=repo_type, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + else: + snapshot_download( + path_or_repo_id, + allow_patterns=full_filenames, + repo_type=repo_type, + revision=revision, + cache_dir=cache_dir, + user_agent=user_agent, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + + except Exception as e: + # We cannot recover from them + if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError): + raise EnvironmentError( + f"{path_or_repo_id} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token " + "having permission to this repo either by logging in with `huggingface-cli login` or by passing " + "`token=`" + ) from e + elif isinstance(e, RevisionNotFoundError): + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists " + "for this model name. Check the model page at " + f"'https://huggingface.co/{path_or_repo_id}' for available revisions." + ) from e + + # Now we try to recover if we can find all files correctly in the cache + resolved_files = [ + _get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames + ] + if all(file is not None for file in resolved_files): + return resolved_files + + # Raise based on the flags. Note that we will raise for missing entries at the very end, even when + # not entering this Except block, as it may also happen when `snapshot_download` does not raise + if isinstance(e, GatedRepoError): + if not _raise_exceptions_for_gated_repo: + return None + raise EnvironmentError( + "You are trying to access a gated repo.\nMake sure to have access to it at " + f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}" + ) from e + elif isinstance(e, LocalEntryNotFoundError): + if not _raise_exceptions_for_connection_errors: + return None + # Here we only raise if both flags for missing entry and connection errors are True (because it can be raised + # even when `local_files_only` is True, in which case raising for connections errors only would not make sense) + elif _raise_exceptions_for_missing_entries: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load the files, and couldn't find them in the" + f" cached files.\nCheckout your internet connection or see how to run the library in offline mode at" + " 'https://huggingface.co/docs/transformers/installation#offline-mode'." + ) from e + # snapshot_download will not raise EntryNotFoundError, but hf_hub_download can. If this is the case, it will be treated + # later on anyway and re-raised if needed + elif isinstance(e, HTTPError) and not isinstance(e, EntryNotFoundError): + if not _raise_exceptions_for_connection_errors: + return None + raise EnvironmentError( + f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}" + ) + + resolved_files = [ + _get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames + ] + # If there are any missing file and the flag is active, raise + if any(file is None for file in resolved_files) and _raise_exceptions_for_missing_entries: + missing_entries = [original for original, resolved in zip(full_filenames, resolved_files) if resolved is None] + # Last escape + if len(resolved_files) == 1 and missing_entries[0] == os.path.join(subfolder, "config.json"): + return None + # Now we raise for missing entries + revision_ = "main" if revision is None else revision + msg = f"a file named {missing_entries[0]}" if len(missing_entries) == 1 else f"files named {*missing_entries,}" + raise EnvironmentError( + f"{path_or_repo_id} does not appear to have {msg}. Checkout 'https://huggingface.co/{path_or_repo_id}/tree/{revision_}'" + "for available files." ) - except GatedRepoError as e: - resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision) - if resolved_file is not None or not _raise_exceptions_for_gated_repo: - return resolved_file - raise EnvironmentError( - "You are trying to access a gated repo.\nMake sure to have access to it at " - f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}" - ) from e - except RepositoryNotFoundError as e: - raise EnvironmentError( - f"{path_or_repo_id} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token " - "having permission to this repo either by logging in with `huggingface-cli login` or by passing " - "`token=`" - ) from e - except RevisionNotFoundError as e: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists " - "for this model name. Check the model page at " - f"'https://huggingface.co/{path_or_repo_id}' for available revisions." - ) from e - except LocalEntryNotFoundError as e: - resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision) - if ( - resolved_file is not None - or not _raise_exceptions_for_missing_entries - or not _raise_exceptions_for_connection_errors - ): - return resolved_file - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the" - f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named" - f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at" - " 'https://huggingface.co/docs/transformers/installation#offline-mode'." - ) from e - except EntryNotFoundError as e: - if not _raise_exceptions_for_missing_entries: - return None - if revision is None: - revision = "main" - if filename in ["config.json", f"{subfolder}/config.json"]: - return None - raise EnvironmentError( - f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout " - f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files." - ) from e - except HTTPError as err: - resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision) - if resolved_file is not None or not _raise_exceptions_for_connection_errors: - return resolved_file - raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}") - except HFValidationError as e: - raise EnvironmentError( - f"Incorrect path_or_model_id: '{path_or_repo_id}'. Please provide either the path to a local folder or the repo_id of a model on the Hub." - ) from e - return resolved_file + + # Remove potential missing entries (we can silently remove them at this point based on the flags) + resolved_files = [file for file in resolved_files if file is not None] + # Return `None` if the list is empty, coherent with other Exception when the flag is not active + resolved_files = None if len(resolved_files) == 0 else resolved_files + + return resolved_files -# TODO: deprecate `get_file_from_repo` or document it differently? -# Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if -# there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error. -# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed). +# TODO cyril: Deprecated and should be removed in 4.51 def get_file_from_repo( - path_or_repo: Union[str, os.PathLike], - filename: str, - cache_dir: Optional[Union[str, os.PathLike]] = None, - force_download: bool = False, - resume_download: Optional[bool] = None, - proxies: Optional[Dict[str, str]] = None, - token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - local_files_only: bool = False, - subfolder: str = "", - **deprecated_kwargs, + *args, + **kwargs, ): """ Tries to locate a file in a local folder and repo, downloads and cache it if necessary. @@ -483,30 +587,15 @@ def get_file_from_repo( tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json") ``` """ - use_auth_token = deprecated_kwargs.pop("use_auth_token", None) - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") - token = use_auth_token - + logger.warning( + "`get_file_from_repo` is deprecated and will be removed in version 4.51. Use `cached_file` instead." + ) return cached_file( - path_or_repo_id=path_or_repo, - filename=filename, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - revision=revision, - local_files_only=local_files_only, - subfolder=subfolder, + *args, _raise_exceptions_for_gated_repo=False, _raise_exceptions_for_missing_entries=False, _raise_exceptions_for_connection_errors=False, + **kwargs, ) @@ -1023,45 +1112,22 @@ def get_checkpoint_shard_files( shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames] return shard_filenames, sharded_metadata - # At this stage pretrained_model_name_or_path is a model identifier on the Hub - cached_filenames = [] - # Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of - # downloaded (if interrupted). - last_shard = try_to_load_from_cache( - pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash + # At this stage pretrained_model_name_or_path is a model identifier on the Hub. Try to get everything from cache, + # or download the files + cached_filenames = cached_files( + pretrained_model_name_or_path, + shard_filenames, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=_commit_hash, ) - show_progress_bar = last_shard is None or force_download - for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar): - try: - # Load from URL - cached_filename = cached_file( - pretrained_model_name_or_path, - shard_filename, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _commit_hash=_commit_hash, - ) - # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so - # we don't have to catch them here. - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is " - "required according to the checkpoint index." - ) - except HTTPError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try" - " again after checking your internet connection." - ) - - cached_filenames.append(cached_filename) return cached_filenames, sharded_metadata diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b8c41c4ed4..c29f58f33f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2368,10 +2368,9 @@ class ModelTesterMixin: safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"}) model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True) - prefix = f"{model_reloaded.base_model_prefix}." params = dict(model_reloaded.named_parameters()) params.update(dict(model_reloaded.named_buffers())) - param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()} + param_names = set(params.keys()) missing_keys = set(infos["missing_keys"]) @@ -2383,9 +2382,8 @@ class ModelTesterMixin: ptrs[id_tensor_storage(tensor)].append(name) tied_params = [names for _, names in ptrs.items() if len(names) > 1] for group in tied_params: - group = {k[len(prefix) :] if k.startswith(prefix) else k for k in group} # We remove the group from extra_missing if not all weights from group are in it - if len(group - extra_missing) > 0: + if len(set(group) - extra_missing) > 0: extra_missing = extra_missing - set(group) self.assertEqual( @@ -2399,15 +2397,14 @@ class ModelTesterMixin: # Remove nonpersistent buffers from missed_missing buffers = [n for n, _ in model_reloaded.named_buffers()] nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()} - nonpersistent_buffers = { - k[len(prefix) :] if k.startswith(prefix) else k for k in nonpersistent_buffers - } missed_missing = missed_missing - nonpersistent_buffers if model_reloaded._keys_to_ignore_on_load_missing is None: expected_missing = set() else: - expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing) + expected_missing = set() + for pattern in model_reloaded._keys_to_ignore_on_load_missing: + expected_missing.update({k for k in param_names if re.search(pattern, k) is not None}) self.assertEqual( missed_missing, expected_missing, diff --git a/tests/utils/test_hub_utils.py b/tests/utils/test_hub_utils.py index aae9bd63cf..ec5887bd16 100644 --- a/tests/utils/test_hub_utils.py +++ b/tests/utils/test_hub_utils.py @@ -28,7 +28,6 @@ from transformers.utils import ( TRANSFORMERS_CACHE, WEIGHTS_NAME, cached_file, - get_file_from_repo, has_file, ) @@ -87,14 +86,8 @@ class GetFromCacheTests(unittest.TestCase): path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False) self.assertIsNone(path) - response_mock = mock.Mock() - response_mock.status_code = 500 - response_mock.headers = {} - response_mock.raise_for_status.side_effect = HTTPError - response_mock.json.return_value = {} - - # Under the mock environment we get a 500 error when trying to reach the tokenizer. - with mock.patch("requests.Session.request", return_value=response_mock) as mock_head: + # Under the mock environment, hf_hub_download will always raise an HTTPError + with mock.patch("transformers.utils.hub.hf_hub_download", side_effect=HTTPError) as mock_head: path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_connection_errors=False) self.assertIsNone(path) # This check we did call the fake head request @@ -117,18 +110,45 @@ class GetFromCacheTests(unittest.TestCase): assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir) def test_get_file_from_repo_distant(self): - # `get_file_from_repo` returns None if the file does not exist - self.assertIsNone(get_file_from_repo("google-bert/bert-base-cased", "ahah.txt")) + # should return None if the file does not exist + self.assertIsNone( + cached_file( + "google-bert/bert-base-cased", + "ahah.txt", + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + ) # The function raises if the repository does not exist. with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"): - get_file_from_repo("bert-base-case", CONFIG_NAME) + cached_file( + "bert-base-case", + CONFIG_NAME, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) # The function raises if the revision does not exist. with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"): - get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME, revision="ahaha") + cached_file( + "google-bert/bert-base-cased", + CONFIG_NAME, + revision="ahaha", + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) - resolved_file = get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME) + resolved_file = cached_file( + "google-bert/bert-base-cased", + CONFIG_NAME, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) # The name is the cached name which is not very easy to test, so instead we load the content. config = json.loads(open(resolved_file, "r").read()) self.assertEqual(config["hidden_size"], 768) @@ -137,9 +157,26 @@ class GetFromCacheTests(unittest.TestCase): with tempfile.TemporaryDirectory() as tmp_dir: filename = Path(tmp_dir) / "a.txt" filename.touch() - self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename)) + self.assertEqual( + cached_file( + tmp_dir, + "a.txt", + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ), + str(filename), + ) - self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt")) + self.assertIsNone( + cached_file( + tmp_dir, + "b.txt", + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + ) def test_get_file_gated_repo(self): """Test download file from a gated repo fails with correct message when not authenticated.""" diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index e5bb3490de..a85b598c08 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -14,7 +14,6 @@ # limitations under the License. import copy import glob -import itertools import json import os import os.path @@ -525,13 +524,12 @@ class ModelUtilsTest(TestCasePlus): self.assertEqual(model.vision_tower.dtype, torch.bfloat16) self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16) - # TODO @ARTHURZUCKER FIX THIS # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what - # LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"] - # model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto") - # self.assertEqual(model.language_model.dtype, torch.float32) - # self.assertEqual(model.vision_tower.dtype, torch.bfloat16) - # self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32) + LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"] + model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto") + self.assertEqual(model.language_model.dtype, torch.float32) + self.assertEqual(model.vision_tower.dtype, torch.bfloat16) + self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32) # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type with self.assertRaises(ValueError): @@ -540,20 +538,6 @@ class ModelUtilsTest(TestCasePlus): TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"} ) - @require_torch - @unittest.skip("Broken by @arthurzucker because the fix was not correct. Knowing the context is super hard") - def test_model_from_pretrained_meta_device(self): - def is_on_meta(model_id, dtype): - with torch.device("meta"): - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype) - return all(value.device.type == "meta" for value in model.state_dict().values()) - - model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing") - dtypes = (None, "auto", torch.float16) - - for model_id, dtype in itertools.product(model_ids, dtypes): - self.assertTrue(is_on_meta(model_id, dtype)) - def test_model_from_pretrained_torch_dtype(self): # test that the model can be instantiated with dtype of either # 1. explicit from_pretrained's torch_dtype argument