From b473617d639bc43f05050020abe9ad37d25c5240 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 25 Mar 2022 11:59:25 -0400 Subject: [PATCH] Checkpoint sharding (#16343) * Sharded checkpoint support * Handle distant sharded checkpoints * Add tests * TODO is done * Apply suggestions from code review Co-authored-by: Stas Bekman * Fix docstring * Add example and format * Address review comments * More review comments * End of merge * Revert unintentional change * VsCode what did you do? * Style * Changes * Address final comments * Quality * Moar tests * Move import beneath is_pt_available Co-authored-by: Stas Bekman --- src/transformers/file_utils.py | 1 + src/transformers/modeling_utils.py | 730 +++++++++++++++++++++++------ src/transformers/utils/__init__.py | 1 + tests/test_modeling_common.py | 120 ++++- 4 files changed, 710 insertions(+), 142 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index c71a9a9b85..4b93c496ce 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -52,6 +52,7 @@ from .utils import ( USE_JAX, USE_TF, USE_TORCH, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ContextManagers, DummyObject, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index bee88d3cf0..21b8f22691 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -15,11 +15,15 @@ # limitations under the License. 
import inspect +import json import os import re +import shutil +import tempfile from contextlib import contextmanager from dataclasses import dataclass from functools import partial +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch @@ -38,6 +42,7 @@ from .utils import ( FLAX_WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, EntryNotFoundError, ModelOutput, @@ -45,7 +50,6 @@ from .utils import ( RepositoryNotFoundError, RevisionNotFoundError, cached_path, - copy_func, has_file, hf_bucket_url, is_offline_mode, @@ -148,6 +152,272 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil return first_tuple[1].dtype +def convert_file_size_to_int(size: Union[int, str]): + """ + Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes). + + Args: + size (`int` or `str`): The size to convert. Will be directly returned if an `int`. + + Example: + + ```py + >>> convert_file_size_to_int("1MiB") + 1048576 + ``` + """ + if isinstance(size, int): + return size + if size.upper().endswith("GIB"): + return int(size[:-3]) * (2**30) + if size.upper().endswith("MIB"): + return int(size[:-3]) * (2**20) + if size.upper().endswith("KIB"): + return int(size[:-3]) * (2**10) + if size.upper().endswith("GB"): + return int(size[:-2]) * (10**9) + if size.upper().endswith("MB"): + return int(size[:-2]) * (10**6) + if size.upper().endswith("KB"): + return int(size[:-2]) * (10**3) + raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.") + + +def dtype_byte_size(dtype): + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. 
+ + Example: + + ```py + >>> dtype_byte_size(torch.float32) + 4 + ``` + """ + if dtype == torch.bool: + return 1 / 8 + bit_search = re.search("[^\d](\d+)$", str(dtype)) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + + The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no + optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the + limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], + [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + """ + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [] + current_block = {} + current_block_size = 0 + total_size = 0 + + for key, weight in state_dict.items(): + weight_size = weight.numel() * dtype_byte_size(weight.dtype) + + # If this weight is going to tip up over the maximal size, we split. 
+ if current_block_size + weight_size > max_shard_size: + sharded_state_dicts.append(current_block) + current_block = {} + current_block_size = 0 + + current_block[key] = weight + current_block_size += weight_size + total_size += weight_size + + # Add the last block + sharded_state_dicts.append(current_block) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {WEIGHTS_NAME: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shards[shard_file] = shard + for key in shard.keys(): + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +def get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_filename, + cache_dir=None, + force_download=False, + proxies=None, + resume_download=False, + local_files_only=False, + use_auth_token=None, + user_agent=None, + revision=None, + mirror=None, +): + """ + For a given model: + + - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the + Hub + - returns the list of paths to all the shards, as well as some metadata. + + For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the + index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). + """ + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + shard_filenames = sorted(list(set(index["weight_map"].values()))) + sharded_metadata = index["metadata"] + sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) + + # First, let's deal with local folder. 
+ if os.path.isdir(pretrained_model_name_or_path): + shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames] + return shard_filenames, sharded_metadata + + # At this stage pretrained_model_name_or_path is a model identifier on the Hub + cached_filenames = [] + for shard_filename in shard_filenames: + shard_url = hf_bucket_url( + pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror + ) + + try: + # Load from URL + cached_filename = cached_path( + shard_url, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + ) + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is " + "required according to the checkpoint index." + ) + except HTTPError: + raise EnvironmentError( + f"We couldn't connect to 'https://huggingface.co/' to load {shard_filename}. You should try again " + "after checking your internet connection." + ) + + cached_filenames.append(cached_filename) + + return cached_filenames, sharded_metadata + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike]): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + """ + try: + return torch.load(checkpoint_file, map_location="cpu") + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." 
+ ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. 
+ def load(module: nn.Module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + if is_deepspeed_zero3_enabled(): + import deepspeed + + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load, prefix=start_prefix) + + return error_msgs + + class ModuleUtilsMixin: """ A few utilities for `torch.nn.Modules`, to be used as a mixin. @@ -1004,6 +1274,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix state_dict: Optional[dict] = None, save_function: Callable = torch.save, push_to_hub: bool = False, + max_shard_size: Union[int, str] = "10GB", **kwargs, ): """ @@ -1035,6 +1306,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + kwargs: Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
""" @@ -1078,11 +1360,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if ignore_key in state_dict.keys(): del state_dict[ignore_key] - # If we save using the predefined names, we can load using `from_pretrained` - output_model_file = os.path.join(save_directory, WEIGHTS_NAME) - save_function(state_dict, output_model_file) + # Shard the model if it is too big. + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size) - logger.info(f"Model weights saved in {output_model_file}") + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename): + os.remove(full_filename) + + # Save the model + for shard_file, shard in shards.items(): + save_function(shard, os.path.join(save_directory, shard_file)) + + if index is None: + logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") + else: + save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) if push_to_hub: url = self._push_to_hub(repo, commit_message=commit_message) @@ -1293,6 +1596,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: model_kwargs = kwargs + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. 
+ is_sharded = False + sharded_metadata = None # Load model if pretrained_model_name_or_path is not None: pretrained_model_name_or_path = str(pretrained_model_name_or_path) @@ -1309,6 +1616,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + is_sharded = True # At this stage we don't have a weight file so we will raise an error. elif os.path.isfile( os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index") @@ -1382,29 +1693,51 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) except EntryNotFoundError: if filename == WEIGHTS_NAME: - has_file_kwargs = { - "revision": revision, - "mirror": mirror, - "proxies": proxies, - "use_auth_token": use_auth_token, - } - if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but " - "there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those " - "weights." + try: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + archive_file = hf_bucket_url( + pretrained_model_name_or_path, + filename=WEIGHTS_INDEX_NAME, + revision=revision, + mirror=mirror, ) - elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but " - "there is a file for Flax weights. Use `from_flax=True` to load this model from those " - "weights." 
- ) - else: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, " - f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + resolved_archive_file = cached_path( + archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, ) + is_sharded = True + except EntryNotFoundError: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "mirror": mirror, + "proxies": proxies, + "use_auth_token": use_auth_token, + } + if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but " + "there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those " + "weights." + ) + elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but " + "there is a file for Flax weights. Use `from_flax=True` to load this model from those " + "weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, " + f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) else: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." @@ -1439,29 +1772,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: resolved_archive_file = None + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. 
+ if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + revision=revision, + mirror=mirror, + ) + # load pt weights early so that we know which dtype to init the model under if from_pt: - if state_dict is None: - try: - state_dict = torch.load(resolved_archive_file, map_location="cpu") - except Exception as e: - try: - with open(resolved_archive_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please install " - "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " - "you cloned." - ) - else: - raise ValueError from e - except (UnicodeDecodeError, ValueError): - raise OSError( - f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " - f"at '{resolved_archive_file}'. " - "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." - ) - + if not is_sharded: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file) # set dtype to instantiate the model under: # 1. If torch_dtype is not None, we use that dtype # 2. 
If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first @@ -1471,7 +1803,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if torch_dtype is not None: if isinstance(torch_dtype, str): if torch_dtype == "auto": - torch_dtype = next(iter(state_dict.values())).dtype + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + elif not is_sharded: + torch_dtype = next(iter(state_dict.values())).dtype + else: + one_state_dict = load_state_dict(resolved_archive_file) + torch_dtype = next(iter(one_state_dict.values())).dtype + del one_state_dict # free CPU memory else: raise ValueError( f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}" @@ -1480,8 +1819,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if low_cpu_mem_usage: # save the keys - loaded_state_dict_keys = [k for k in state_dict.keys()] - del state_dict # free CPU memory - will reload again later + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + state_dict = load_state_dict(resolved_archive_file) + loaded_state_dict_keys = [k for k in state_dict.keys()] + del state_dict # free CPU memory - will reload again later config.name_or_path = pretrained_model_name_or_path @@ -1534,13 +1877,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix elif from_pt: if low_cpu_mem_usage: - cls._load_state_dict_into_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file) + cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file) else: - model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model( + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, state_dict, + resolved_archive_file, pretrained_model_name_or_path, 
ignore_mismatched_sizes=ignore_mismatched_sizes, + sharded_metadata=sharded_metadata, _fast_init=_fast_init, ) @@ -1562,31 +1907,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return model @classmethod - def _load_state_dict_into_model( - cls, model, state_dict, pretrained_model_name_or_path, ignore_mismatched_sizes=False, _fast_init=True + def _load_pretrained_model( + cls, + model, + state_dict, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + sharded_metadata=None, + _fast_init=True, ): - - # Convert old format to new format if needed from a PyTorch state_dict - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if "gamma" in key: - new_key = key.replace("gamma", "weight") - if "beta" in key: - new_key = key.replace("beta", "bias") - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - # Retrieve missing & unexpected_keys model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) - loaded_keys = list(state_dict.keys()) + loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"] prefix = model.base_model_prefix + def _fix_key(key): + if "beta" in key: + return key.replace("beta", "bias") + if "gamma" in key: + return key.replace("gamma", "weight") + return key + + loaded_keys = [_fix_key(key) for key in loaded_keys] + if len(prefix) > 0: has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) @@ -1608,28 +1953,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix missing_keys = list(set(expected_keys) - set(loaded_keys)) unexpected_keys = list(set(loaded_keys) - set(expected_keys)) - # Mistmatched keys contains tuples key/shape1/shape2 of weights in the 
checkpoint that have a shape not - # matching the weights in the model. - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - if remove_prefix_from_model: - # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. - model_key = f"{prefix}.{checkpoint_key}" - elif add_prefix_to_model: - # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. - model_key = ".".join(checkpoint_key.split(".")[1:]) - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - # Some models may have keys that are not in the state by design, removing them before needlessly warning # the user. if cls._keys_to_ignore_on_load_missing is not None: @@ -1648,35 +1971,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix for module in uninitialized_modules: model._init_weights(module) - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, "_metadata", None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - error_msgs = [] - - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants - # so we need to apply the function recursively. 
- def load(module: nn.Module, prefix=""): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) - if is_deepspeed_zero3_enabled(): - import deepspeed - - # because zero3 puts placeholders in model params, this context - # manager gathers (unpartitions) the params of the current layer, then loads from - # the state dict and then re-partitions them again - with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): - if torch.distributed.get_rank() == 0: - module._load_from_state_dict(*args) - else: - module._load_from_state_dict(*args) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - # Make sure we are able to load base models as well as derived models (with heads) start_prefix = "" model_to_load = model @@ -1690,7 +1984,61 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "properly saved?" ) - load(model_to_load, prefix=start_prefix) + if state_dict is not None: + # Whole checkpoint + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + if remove_prefix_from_model: + # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. + model_key = f"{prefix}.{checkpoint_key}" + elif add_prefix_to_model: + # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. 
+ model_key = ".".join(checkpoint_key.split(".")[1:]) + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + + error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) + else: + # Sharded checkpoint + # This should always be a list but, just to be sure. + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + error_msgs = [] + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + if remove_prefix_from_model: + # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. + model_key = f"{prefix}.{checkpoint_key}" + elif add_prefix_to_model: + # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. 
+ model_key = ".".join(checkpoint_key.split(".")[1:]) + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + + error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) if len(error_msgs) > 0: error_msg = "\n\t".join(error_msgs) @@ -1755,7 +2103,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix return retrieved_modules @classmethod - def _load_state_dict_into_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file): + def _load_pretrained_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file): """ This is an experimental function that loads the model using ~1.x model size CPU memory @@ -1772,7 +2120,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. 
""" - require_version_core("torch>=1.9") if is_deepspeed_zero3_enabled(): raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3") @@ -1806,19 +2153,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix new_val = new_val.to("meta") setattr(submodule, param_name, new_val) - # only now can load state_dict - state_dict = torch.load(resolved_archive_file, map_location="cpu") + # only now can load state_dict(s) + if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] - # materialize state_dict entries one by one on CPU - for k in loaded_state_dict_keys: - submodule, param_name = find_submodule_and_param_name(model, k) - if submodule is not None: - new_val = state_dict[k] - if isinstance(getattr(submodule, param_name), torch.nn.Parameter): - new_val = torch.nn.Parameter(new_val) - setattr(submodule, param_name, new_val) + for archive_file in resolved_archive_file: + state_dict = torch.load(resolved_archive_file, map_location="cpu") - del state_dict + # materialize state_dict entries one by one on CPU + for k in loaded_state_dict_keys: + submodule, param_name = find_submodule_and_param_name(model, k) + if submodule is not None: + new_val = state_dict[k] + if isinstance(getattr(submodule, param_name), torch.nn.Parameter): + new_val = torch.nn.Parameter(new_val) + setattr(submodule, param_name, new_val) + + del state_dict @classmethod def register_for_auto_class(cls, auto_class="AutoModel"): @@ -1846,12 +2197,109 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix cls._auto_class = auto_class + def push_to_hub( + self, + repo_path_or_name: Optional[str] = None, + repo_url: Optional[str] = None, + use_temp_dir: bool = False, + commit_message: str = "add model", + organization: Optional[str] = None, + private: Optional[bool] = None, + use_auth_token: Optional[Union[bool, str]] = None, + max_shard_size: Union[int, str] = "10GB", + 
**model_card_kwargs + ) -> str: + """ + Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`. -# To update the docstring, we need to copy the method, otherwise we change the original docstring. -PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) -PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format( - object="model", object_class="AutoModel", object_files="model checkpoint" -) + Parameters: + repo_path_or_name (`str`, *optional*): + Can either be a repository name for your model in the Hub or a path to a local folder (in which case + the repository will have the name of that local folder). If not specified, will default to the name + given by `repo_url` and a local directory with that name will be created. + repo_url (`str`, *optional*): + Specify this in case you want to push to an existing repository in the hub. If unspecified, a new + repository will be created in your namespace (unless you specify an `organization`) with `repo_name`. + use_temp_dir (`bool`, *optional*, defaults to `False`): + Whether or not to clone the distant repo in a temporary directory or in `repo_path_or_name` inside the + current working directory. This will slow things down if you are making changes in an existing repo + since you will need to clone the repo before every push. + commit_message (`str`, *optional*, defaults to `"add model"`): + Message to commit while pushing. + organization (`str`, *optional*): + Organization in which you want to push your {object} (you must be a member of this organization). + private (`bool`, *optional*): + Whether or not the repository created should be private (requires a paying subscription). + use_auth_token (`bool` or `str`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). 
Will default to `True` if + `repo_url` is not specified. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be + smaller than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + <Tip warning={true}> + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + </Tip> + + Returns: + `str`: The url of the commit of your model in the given repository. + + Examples: + + ```python + from transformers import AutoModel + + model = AutoModel.from_pretrained("bert-base-cased") + + # Push the model to your namespace with the name "my-finetuned-bert" and have a local clone in the + # *my-finetuned-bert* folder. + model.push_to_hub("my-finetuned-bert") + + # Push the model to your namespace with the name "my-finetuned-bert" with no local clone. + model.push_to_hub("my-finetuned-bert", use_temp_dir=True) + + # Push the model to an organization with the name "my-finetuned-bert" and have a local clone in the + # *my-finetuned-bert* folder. + model.push_to_hub("my-finetuned-bert", organization="huggingface") + + # Make a change to an existing repo that has been cloned locally in *my-finetuned-bert*. + model.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert") + ``` + """ + if use_temp_dir: + # Make sure we use the right `repo_name` for the `repo_url` before replacing it. + if repo_url is None: + if use_auth_token is None: + use_auth_token = True + repo_name = Path(repo_path_or_name).name + repo_url = self._get_repo_url_from_name( + repo_name, organization=organization, private=private, use_auth_token=use_auth_token + ) + repo_path_or_name = tempfile.mkdtemp() + + # Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
+ repo = self._create_or_get_repo( + repo_path_or_name=repo_path_or_name, + repo_url=repo_url, + organization=organization, + private=private, + use_auth_token=use_auth_token, + ) + # Save the files in the cloned repo + self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size) + + # Commit and push! + url = self._push_to_hub(repo, commit_message=commit_message) + + # Clean up! Clean up! Everybody everywhere! + if use_temp_dir: + shutil.rmtree(repo_path_or_name) + + return url class Conv1D(nn.Module): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index b8c6cb65af..af326b53e8 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -136,6 +136,7 @@ from .import_utils import ( WEIGHTS_NAME = "pytorch_model.bin" +WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" TF2_WEIGHTS_NAME = "tf_model.h5" TF_WEIGHTS_NAME = "model.ckpt" FLAX_WEIGHTS_NAME = "flax_model.msgpack" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6ed8dd3c97..8e24175552 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -55,7 +55,7 @@ from transformers.testing_utils import ( slow, torch_device, ) -from transformers.utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available +from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_flax_available, is_torch_fx_available sys.path.append(str(Path(__file__).parent.parent / "utils")) @@ -90,6 +90,7 @@ if is_torch_available(): T5Config, T5ForConditionalGeneration, ) + from transformers.modeling_utils import shard_checkpoint if is_flax_available(): import jax.numpy as jnp @@ -2352,6 +2353,123 @@ class ModelUtilsTest(TestCasePlus): for p1, p2 in zip(model.parameters(), new_model.parameters()): self.assertTrue(torch.equal(p1, p2)) + def test_shard_checkpoint(self): + # This is the model we will use, total size 340,000 bytes. 
+ model = torch.nn.Sequential( + torch.nn.Linear(100, 200, bias=False), # size 80,000 + torch.nn.Linear(200, 200, bias=False), # size 160,000 + torch.nn.Linear(200, 100, bias=False), # size 80,000 + torch.nn.Linear(100, 50, bias=False), # size 20,000 + ) + state_dict = model.state_dict() + + with self.subTest("No shard when max size is bigger than model size"): + shards, index = shard_checkpoint(state_dict) + self.assertIsNone(index) + self.assertDictEqual(shards, {WEIGHTS_NAME: state_dict}) + + with self.subTest("Test sharding, no weights bigger than max size"): + shards, index = shard_checkpoint(state_dict, max_shard_size="300kB") + # Split is first two layers then last two. + self.assertDictEqual( + index, + { + "metadata": {"total_size": 340000}, + "weight_map": { + "0.weight": "pytorch_model-00001-of-00002.bin", + "1.weight": "pytorch_model-00001-of-00002.bin", + "2.weight": "pytorch_model-00002-of-00002.bin", + "3.weight": "pytorch_model-00002-of-00002.bin", + }, + }, + ) + + shard1 = {"0.weight": state_dict["0.weight"], "1.weight": state_dict["1.weight"]} + shard2 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]} + self.assertDictEqual( + shards, {"pytorch_model-00001-of-00002.bin": shard1, "pytorch_model-00002-of-00002.bin": shard2} + ) + + with self.subTest("Test sharding with weights bigger than max size"): + shards, index = shard_checkpoint(state_dict, max_shard_size="100kB") + # Split is first layer, second layer then last 2. 
+ self.assertDictEqual( + index, + { + "metadata": {"total_size": 340000}, + "weight_map": { + "0.weight": "pytorch_model-00001-of-00003.bin", + "1.weight": "pytorch_model-00002-of-00003.bin", + "2.weight": "pytorch_model-00003-of-00003.bin", + "3.weight": "pytorch_model-00003-of-00003.bin", + }, + }, + ) + + shard1 = {"0.weight": state_dict["0.weight"]} + shard2 = {"1.weight": state_dict["1.weight"]} + shard3 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]} + self.assertDictEqual( + shards, + { + "pytorch_model-00001-of-00003.bin": shard1, + "pytorch_model-00002-of-00003.bin": shard2, + "pytorch_model-00003-of-00003.bin": shard3, + }, + ) + + def test_checkpoint_sharding_local(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + + with tempfile.TemporaryDirectory() as tmp_dir: + # We use the same folder for various sizes to make sure a new save erases the old checkpoint. + for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]: + model.save_pretrained(tmp_dir, max_shard_size=max_size) + + # Get each shard file and its size + shard_to_size = {} + for shard in os.listdir(tmp_dir): + if shard.endswith(".bin"): + shard_file = os.path.join(tmp_dir, shard) + shard_to_size[shard_file] = os.path.getsize(shard_file) + + index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME) + # Check there is an index but no regular weight file + self.assertTrue(os.path.isfile(index_file)) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME))) + + # Check a file is bigger than max_size only when it has a single weight + for shard_file, size in shard_to_size.items(): + if max_size.endswith("kiB"): + max_size_int = int(max_size[:-3]) * 2**10 + else: + max_size_int = int(max_size[:-2]) * 10**3 + # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than + # the size asked for (since we count parameters) + if size >= max_size_int + 50000: + state_dict = 
torch.load(shard_file) + self.assertEqual(len(state_dict), 1) + + # Check the index and the shard files found match + with open(index_file, "r", encoding="utf-8") as f: + index = json.loads(f.read()) + + all_shards = set(index["weight_map"].values()) + shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".bin")) + self.assertSetEqual(all_shards, shards_found) + + # Finally, check the model can be reloaded + new_model = BertModel.from_pretrained(tmp_dir) + for p1, p2 in zip(model.parameters(), new_model.parameters()): + self.assertTrue(torch.allclose(p1, p2)) + + def test_checkpoint_sharding_from_hub(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded") + # the model above is the same as the model below, just a sharded version. + ref_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + for p1, p2 in zip(model.parameters(), ref_model.parameters()): + self.assertTrue(torch.allclose(p1, p2)) + def test_cached_files_are_used_when_internet_is_down(self): # A mock response for an HTTP head request to emulate server down response_mock = mock.Mock()