From 7cced021fa8ddc59f0f77384300760d34545394e Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 21 Jun 2022 18:01:08 +0200 Subject: [PATCH] TF Sharded (#17713) * initial commit * update modeeling tf utils * quality * clean and update args * update * remove potential bug * code quality * update * update max shard * update tests for sharding from pretrained * fix remaining test * make style * h5py if tf available * update and fix test * fix test * style * modified push to hub to support shard for TF * quick fix * update code * merge branch main and style * Apply suggestions from code review Co-authored-by: Joao Gante Co-authored-by: Patrick von Platen * update based on reviews * update doc * update and style * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update based on reviews * fix typo * style Co-authored-by: Joao Gante Co-authored-by: Patrick von Platen Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_tf_utils.py | 414 ++++++++++++++++++++++++-- src/transformers/utils/__init__.py | 1 + src/transformers/utils/hub.py | 118 +++++++- tests/test_modeling_tf_common.py | 131 +++++++- 4 files changed, 633 insertions(+), 31 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 70d0d489a8..62b0c48880 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -16,7 +16,9 @@ """TF general model utils.""" import functools +import gc import inspect +import json import os import pickle import re @@ -33,7 +35,9 @@ from tensorflow.python.keras.engine.keras_tensor import KerasTensor from tensorflow.python.keras.saving import hdf5_format from huggingface_hub import Repository, list_repo_files +from keras.saving.hdf5_format import save_attributes_to_hdf5_group from requests import HTTPError +from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files from . import DataCollatorWithPadding, DefaultDataCollator from .activations_tf import get_tf_activation @@ -44,6 +48,7 @@ from .tf_utils import shape_list from .utils import ( DUMMY_INPUTS, HUGGINGFACE_CO_RESOLVE_ENDPOINT, + TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, WEIGHTS_NAME, EntryNotFoundError, @@ -554,9 +559,243 @@ def input_processing(func, config, input_ids, **kwargs): return output +def dtype_byte_size(dtype): + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. + + Example: + + ```py + >>> dtype_byte_size(tf.float32) + 4 + ``` + """ + if dtype == tf.bool: + return 1 / 8 + bit_search = re.search("[^\d](\d+)$", dtype.name) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def tf_shard_checkpoint(weights, max_shard_size="10GB"): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + + The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no + optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the + limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], + [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. + + + + Args: + weights (`Dict[str, tf.RessourceVariable]`): The list of tf.RessourceVariable of a model to save. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + """ + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [] + current_block = [] + current_block_size = 0 + total_size = 0 + + for item in weights: + weight_size = item.numpy().size * dtype_byte_size(item.dtype) + + # If this weight is going to tip up over the maximal size, we split. + if current_block_size + weight_size > max_shard_size: + sharded_state_dicts.append(current_block) + current_block = [] + current_block_size = 0 + + current_block.append(item) + current_block_size += weight_size + total_size += weight_size + + # Add the last block + sharded_state_dicts.append(current_block) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5") + shards[shard_file] = shard + for weight in shard: + weight_name = weight.name + weight_map[weight_name] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=True): + """ + This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load + the TF weights from the shard file accordingly to their names and shapes. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + model (`tf.keras.models.Model`): The model in which to load the checkpoint. + shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names. + ignore_mismatched_sizes`bool`, *optional`, defaults to `True`): + Whether or not to ignore the mismatch between the sizes + strict (`bool`, *optional*, defaults to `True`): + Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + + Returns: + Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the + mismatched layers. + """ + + # Load the index + missing_keys = [] + unexpected_keys = set() + saved_keys = set() + missmatched_keys = set() + + # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load + # the weight, we have to get rid of the first prefix of the name of the layer. + model_keys = set("/".join(k.name.split("/")[1:]) for k in model.weights) + model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)} + + for shard_file in shard_files: + state_dict = tf.io.read_file(shard_file) + saved_weight_names_set, unexpected_keys_set, missmatched_keys_set = load_tf_shard( + model, model_layer_map, shard_file, ignore_mismatched_sizes=ignore_mismatched_sizes + ) + saved_keys.update(saved_weight_names_set) + unexpected_keys.update(unexpected_keys_set) + missmatched_keys.update(missmatched_keys_set) + del state_dict + gc.collect() + + missing_keys = model_keys - saved_keys + if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): + error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" + if len(missing_keys) > 0: + str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) + error_message += f"\nMissing key(s): {str_missing_keys}." + if len(unexpected_keys) > 0: + str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) + error_message += f"\nMissing key(s): {str_unexpected_keys}." + raise RuntimeError(error_message) + + return missing_keys, unexpected_keys, missmatched_keys + + +def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False): + """ + Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys. + + Args: + model (`tf.keras.models.Model`): Model in which the weights are loaded + model_layer_map (`Dict`): A dictionnary mapping the layer name to the index of the layer in the model. + resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys + + Returns: + `tf.keras.models.Model`: Three lists, one for the layers that were found and succesfully restored (from the + shard file), one for the missmatched layers, and another one for the unexpected layers. + """ + saved_weight_names_set = set() + saved_weights = {} + missmatched_keys = set() + unexpected_keys = set() + # Read the H5 file + try: + with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file: + # Retrieve the name of each layer from the H5 file + saved_h5_model_layers_name = set( + hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names") + ) + weight_value_tuples = [] + + # Compute missing and unexpected sub layers + # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] + for layer_name in saved_h5_model_layers_name: + h5_layer_object = sharded_checkpoint_file[layer_name] + saved_weights[layer_name] = np.asarray(h5_layer_object) + + saved_weight_names_set.add(layer_name) + + if layer_name not in model_layer_map: + unexpected_keys.add(layer_name) + else: + symbolic_weight = model.weights[model_layer_map[layer_name]] + + saved_weight_value = saved_weights[layer_name] + # If the current weight is found + if saved_weight_value is not None: + # Check if the shape of the current weight and the one from the H5 file are different + if K.int_shape(symbolic_weight) != saved_weight_value.shape: + # If yes we reshape the weight from the H5 file accordingly to the current weight + # If the two shapes are not compatible we raise an issue + try: + array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) + except ValueError as e: + if ignore_mismatched_sizes: + missmatched_keys.add( + (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight)) + ) + continue + else: + raise e + else: + array = saved_weight_value + + # We create the tuple that will be loaded and add it to the final list + weight_value_tuples.append((symbolic_weight, array)) + + K.batch_set_value(weight_value_tuples) + + return saved_weight_names_set, unexpected_keys, missmatched_keys + + except Exception as e: + try: + with open(resolved_archive_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained" + " model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' " + f"at '{resolved_archive_file}'. " + "If you tried to load a TF model from a sharded checkpoint, you should try converting the model" + "by loading it in pytorch and saving it localy. A convertion script should be realeased soon." + ) + + def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): """ - Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes. + Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and + shapes. Args: model (`tf.keras.models.Model`): @@ -575,9 +814,11 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, mismatched_layers = [] # Read the H5 file - with h5py.File(resolved_archive_file, "r") as f: + with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file: # Retrieve the name of each layer from the H5 file - saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) + saved_h5_model_layers_name = set( + hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names") + ) # Find the missing layers from the high level list of layers missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name) @@ -594,7 +835,7 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, # if layer_name from the H5 file belongs to the layers from the instantiated model if layer.name in saved_h5_model_layers_name: # Get the H5 layer object from its name - h5_layer_object = f[layer.name] + h5_layer_object = sharded_checkpoint_file[layer.name] # Get all the weights as a list from the layer object symbolic_weights = layer.trainable_weights + layer.non_trainable_weights saved_weights = {} @@ -1624,7 +1865,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu """ raise NotImplementedError - def save_pretrained(self, save_directory, saved_model=False, version=1, push_to_hub=False, **kwargs): + def save_pretrained( + self, + save_directory, + saved_model=False, + version=1, + push_to_hub=False, + max_shard_size: Union[int, str] = "10GB", + **kwargs + ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the [`~TFPreTrainedModel.from_pretrained`] class method. @@ -1649,6 +1898,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + kwargs: Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -1679,8 +1939,48 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu # If we save using the predefined names, we can load using `from_pretrained` output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME) - self.save_weights(output_model_file) - logger.info(f"Model weights saved in {output_model_file}") + + shards, index = tf_shard_checkpoint(self.weights, max_shard_size) + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + if ( + filename.startswith(TF2_WEIGHTS_NAME[:-4]) + and os.path.isfile(full_filename) + and filename not in shards.keys() + ): + os.remove(full_filename) + + if index is None: + self.save_weights(output_model_file) + logger.info(f"Model weights saved in {output_model_file}") + else: + save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as index_file: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + index_file.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + for shard_file, shard in shards.items(): + with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file: + save_attributes_to_hdf5_group( + shard_file, + "layer_names", + ["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard], + ) + + for layer in sorted(shard, key=lambda x: x.name): + param_dset = shard_file.create_dataset( + "/".join(layer.name.split("/")[1:]), layer.numpy().shape, dtype=layer.numpy().dtype + ) + param_dset[:] = layer.numpy() if push_to_hub: url = self._push_to_hub(repo, commit_message=commit_message) @@ -1844,6 +2144,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu else: model_kwargs = kwargs + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + is_sharded = False + sharded_metadata = None # Load model if pretrained_model_name_or_path is not None: if os.path.isdir(pretrained_model_name_or_path): @@ -1853,6 +2157,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): # Load from a TF 2.0 checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) + is_sharded = True # At this stage we don't have a weight file so we will raise an error. elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): raise EnvironmentError( @@ -1906,23 +2214,45 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ) except EntryNotFoundError: if filename == TF2_WEIGHTS_NAME: - has_file_kwargs = { - "revision": revision, - "mirror": mirror, - "proxies": proxies, - "use_auth_token": use_auth_token, - } - if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} " - "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from " - "those weights." + try: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + archive_file = hf_bucket_url( + pretrained_model_name_or_path, + filename=TF2_WEIGHTS_INDEX_NAME, + revision=revision, + mirror=mirror, ) - else: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} " - f"or {WEIGHTS_NAME}." + resolved_archive_file = cached_path( + archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, ) + is_sharded = True + except EntryNotFoundError: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "mirror": mirror, + "proxies": proxies, + "use_auth_token": use_auth_token, + } + if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" + " load this model from those weights." + ) + else: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) else: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." @@ -1955,6 +2285,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu else: resolved_archive_file = None + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + revision=revision, + mirror=mirror, + ) + config.name_or_path = pretrained_model_name_or_path # composed models, *e.g.* TFRag, require special treatment when it comes to loading @@ -1978,16 +2325,25 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu else: model(model.dummy_inputs) # build the network with dummy inputs - assert os.path.isfile(resolved_archive_file), f"Error retrieving file {resolved_archive_file}" # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 try: - missing_keys, unexpected_keys, mismatched_keys = load_tf_weights( - model, - resolved_archive_file, - ignore_mismatched_sizes=ignore_mismatched_sizes, - _prefix=load_weight_prefix, - ) + if is_sharded: + for file in resolved_archive_file: + os.path.isfile(file), f"Error retrieving files {file}" + + missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + else: + missing_keys, unexpected_keys, mismatched_keys = load_tf_weights( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) except OSError as e: try: with open(resolved_archive_file) as f: diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index c793185ddf..d7795eba42 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -148,6 +148,7 @@ from .import_utils import ( WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" TF2_WEIGHTS_NAME = "tf_model.h5" +TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json" TF_WEIGHTS_NAME = "model.ckpt" FLAX_WEIGHTS_NAME = "flax_model.msgpack" CONFIG_NAME = "config.json" diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index a2ec88008e..dd1519d4a1 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -861,6 +861,7 @@ class PushToHubMixin: organization: Optional[str] = None, private: Optional[bool] = None, use_auth_token: Optional[Union[bool, str]] = None, + max_shard_size: Optional[Union[int, str]] = "10GB", **model_card_kwargs ) -> str: """ @@ -936,8 +937,9 @@ class PushToHubMixin: use_auth_token=use_auth_token, ) # Save the files in the cloned repo - self.save_pretrained(repo_path_or_name) + if hasattr(self, "history") and hasattr(self, "create_model_card"): + self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size) # This is a Keras model and we might be able to fish out its History and make a model card out of it base_model_card_args = { "output_dir": repo_path_or_name, @@ -945,6 +947,9 @@ class PushToHubMixin: } base_model_card_args.update(model_card_kwargs) self.create_model_card(**base_model_card_args) + else: + # FLAX does not support sharding yet, will come in next PR + self.save_pretrained(repo_path_or_name) # Commit and push! url = self._push_to_hub(repo, commit_message=commit_message) @@ -1075,3 +1080,114 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"): except Exception: # We don't want to error in case of connection errors of any kind. pass + + +def convert_file_size_to_int(size: Union[int, str]): + """ + Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes). + + Args: + size (`int` or `str`): The size to convert. Will be directly returned if an `int`. + + Example: + + ```py + >>> convert_file_size_to_int("1MiB") + 1048576 + ``` + """ + if isinstance(size, int): + return size + if size.upper().endswith("GIB"): + return int(size[:-3]) * (2**30) + if size.upper().endswith("MIB"): + return int(size[:-3]) * (2**20) + if size.upper().endswith("KIB"): + return int(size[:-3]) * (2**10) + if size.upper().endswith("GB"): + int_size = int(size[:-2]) * (10**9) + return int_size // 8 if size.endswith("b") else int_size + if size.upper().endswith("MB"): + int_size = int(size[:-2]) * (10**6) + return int_size // 8 if size.endswith("b") else int_size + if size.upper().endswith("KB"): + int_size = int(size[:-2]) * (10**3) + return int_size // 8 if size.endswith("b") else int_size + raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.") + + +def get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_filename, + cache_dir=None, + force_download=False, + proxies=None, + resume_download=False, + local_files_only=False, + use_auth_token=None, + user_agent=None, + revision=None, + mirror=None, +): + """ + For a given model: + + - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the + Hub + - returns the list of paths to all the shards, as well as some metadata. + + For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the + index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). + """ + import json + + if not os.path.isfile(index_filename): + raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") + + with open(index_filename, "r") as f: + index = json.loads(f.read()) + + shard_filenames = sorted(list(set(index["weight_map"].values()))) + sharded_metadata = index["metadata"] + sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) + + # First, let's deal with local folder. + if os.path.isdir(pretrained_model_name_or_path): + shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames] + return shard_filenames, sharded_metadata + + # At this stage pretrained_model_name_or_path is a model identifier on the Hub + cached_filenames = [] + for shard_filename in shard_filenames: + shard_url = hf_bucket_url( + pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror + ) + + try: + # Load from URL + cached_filename = cached_path( + shard_url, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + ) + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is " + "required according to the checkpoint index." + ) + except HTTPError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try" + " again after checking your internet connection." + ) + + cached_filenames.append(cached_filename) + + return cached_filenames, sharded_metadata diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index b9f51a662c..843ddaa5e3 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -53,6 +53,7 @@ logger = logging.get_logger(__name__) if is_tf_available(): + import h5py import numpy as np import tensorflow as tf @@ -85,7 +86,12 @@ if is_tf_available(): TFSampleDecoderOnlyOutput, TFSampleEncoderDecoderOutput, ) - from transformers.modeling_tf_utils import unpack_inputs + from transformers.modeling_tf_utils import ( + TF2_WEIGHTS_INDEX_NAME, + TF2_WEIGHTS_NAME, + tf_shard_checkpoint, + unpack_inputs, + ) from transformers.tf_utils import stable_softmax if _tf_gpu_memory_limit is not None: @@ -1867,6 +1873,129 @@ class UtilsFunctionsTest(unittest.TestCase): out = masked_softmax(x, boolean_mask) assert tf.experimental.numpy.allclose(xla_out, out) + def test_checkpoint_sharding_from_hub(self): + model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded") + # the model above is the same as the model below, just a sharded version. + ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + for p1, p2 in zip(model.weights, ref_model.weights): + assert np.allclose(p1.numpy(), p2.numpy()) + + def test_shard_checkpoint(self): + # This is the model we will use, total size 340,000 bytes. + model = tf.keras.Sequential( + [ + tf.keras.layers.Dense(200, use_bias=False), # size 80,000 + tf.keras.layers.Dense(200, use_bias=False), # size 160,000 + tf.keras.layers.Dense(100, use_bias=False), # size 80,000 + tf.keras.layers.Dense(50, use_bias=False), # size 20,000 + ] + ) + inputs = tf.zeros((1, 100), dtype=tf.float32) + model(inputs) + weights = model.weights + weights_dict = {w.name: w for w in weights} + with self.subTest("No shard when max size is bigger than model size"): + shards, index = tf_shard_checkpoint(weights) + self.assertIsNone(index) + self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights}) + + with self.subTest("Test sharding, no weights bigger than max size"): + shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB") + # Split is first two layers then last two. + self.assertDictEqual( + index, + { + "metadata": {"total_size": 340000}, + "weight_map": { + "dense/kernel:0": "tf_model-00001-of-00002.h5", + "dense_1/kernel:0": "tf_model-00001-of-00002.h5", + "dense_2/kernel:0": "tf_model-00002-of-00002.h5", + "dense_3/kernel:0": "tf_model-00002-of-00002.h5", + }, + }, + ) + + shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]] + shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]] + self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2}) + + with self.subTest("Test sharding with weights bigger than max size"): + shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB") + # Split is first layer, second layer then last 2. + self.assertDictEqual( + index, + { + "metadata": {"total_size": 340000}, + "weight_map": { + "dense/kernel:0": "tf_model-00001-of-00003.h5", + "dense_1/kernel:0": "tf_model-00002-of-00003.h5", + "dense_2/kernel:0": "tf_model-00003-of-00003.h5", + "dense_3/kernel:0": "tf_model-00003-of-00003.h5", + }, + }, + ) + + shard1 = [weights_dict["dense/kernel:0"]] + shard2 = [weights_dict["dense_1/kernel:0"]] + shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]] + self.assertDictEqual( + shards, + { + "tf_model-00001-of-00003.h5": shard1, + "tf_model-00002-of-00003.h5": shard2, + "tf_model-00003-of-00003.h5": shard3, + }, + ) + + def test_checkpoint_sharding_local(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + + with tempfile.TemporaryDirectory() as tmp_dir: + # We use the same folder for various sizes to make sure a new save erases the old checkpoint. + for max_size in ["150kB", "150kiB", "200kB", "200kiB"]: + model.save_pretrained(tmp_dir, max_shard_size=max_size) + + # Get each shard file and its size + shard_to_size = {} + for shard in os.listdir(tmp_dir): + if shard.endswith(".h5"): + shard_file = os.path.join(tmp_dir, shard) + shard_to_size[shard_file] = os.path.getsize(shard_file) + + index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME) + # Check there is an index but no regular weight file + self.assertTrue(os.path.isfile(index_file)) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME))) + + # Check a file is bigger than max_size only when it has a single weight + for shard_file, size in shard_to_size.items(): + if max_size.endswith("kiB"): + max_size_int = int(max_size[:-3]) * 2**10 + else: + max_size_int = int(max_size[:-2]) * 10**3 + # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than + # the size asked for (since we count parameters) + if size >= max_size_int + 50000: + with h5py.File(shard_file, "r") as state_file: + self.assertEqual(len(state_file), 1) + + # Check the index and the shard files found match + with open(index_file, "r", encoding="utf-8") as f: + index = json.loads(f.read()) + + all_shards = set(index["weight_map"].values()) + shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".h5")) + self.assertSetEqual(all_shards, shards_found) + + # Finally, check the model can be reloaded + new_model = TFBertModel.from_pretrained(tmp_dir) + + model(model.dummy_inputs) + new_model(model.dummy_inputs) + + for p1, p2 in zip(model.weights, new_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + @require_tf @is_staging_test