TF Sharded (#17713)
* initial commit * update modeeling tf utils * quality * clean and update args * update * remove potential bug * code quality * update * update max shard * update tests for sharding from pretrained * fix remaining test * make style * h5py if tf available * update and fix test * fix test * style * modified push to hub to support shard for TF * quick fix * update code * merge branch main and style * Apply suggestions from code review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * update based on reviews * update doc * update and style * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update based on reviews * fix typo * style Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,9 @@
|
||||
"""TF general model utils."""
|
||||
|
||||
import functools
|
||||
import gc
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
@@ -33,7 +35,9 @@ from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
|
||||
from huggingface_hub import Repository, list_repo_files
|
||||
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
|
||||
from requests import HTTPError
|
||||
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||
|
||||
from . import DataCollatorWithPadding, DefaultDataCollator
|
||||
from .activations_tf import get_tf_activation
|
||||
@@ -44,6 +48,7 @@ from .tf_utils import shape_list
|
||||
from .utils import (
|
||||
DUMMY_INPUTS,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
EntryNotFoundError,
|
||||
@@ -554,9 +559,243 @@ def input_processing(func, config, input_ids, **kwargs):
|
||||
return output
|
||||
|
||||
|
||||
def dtype_byte_size(dtype):
|
||||
"""
|
||||
Returns the size (in bytes) occupied by one parameter of type `dtype`.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> dtype_byte_size(tf.float32)
|
||||
4
|
||||
```
|
||||
"""
|
||||
if dtype == tf.bool:
|
||||
return 1 / 8
|
||||
bit_search = re.search("[^\d](\d+)$", dtype.name)
|
||||
if bit_search is None:
|
||||
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||
bit_size = int(bit_search.groups()[0])
|
||||
return bit_size // 8
|
||||
|
||||
|
||||
def tf_shard_checkpoint(weights, max_shard_size="10GB"):
|
||||
"""
|
||||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
given size.
|
||||
|
||||
The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
|
||||
optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
|
||||
limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
|
||||
[6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will
|
||||
have a size greater than `max_shard_size`.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
weights (`Dict[str, tf.RessourceVariable]`): The list of tf.RessourceVariable of a model to save.
|
||||
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
|
||||
(like `"5MB"`).
|
||||
"""
|
||||
max_shard_size = convert_file_size_to_int(max_shard_size)
|
||||
|
||||
sharded_state_dicts = []
|
||||
current_block = []
|
||||
current_block_size = 0
|
||||
total_size = 0
|
||||
|
||||
for item in weights:
|
||||
weight_size = item.numpy().size * dtype_byte_size(item.dtype)
|
||||
|
||||
# If this weight is going to tip up over the maximal size, we split.
|
||||
if current_block_size + weight_size > max_shard_size:
|
||||
sharded_state_dicts.append(current_block)
|
||||
current_block = []
|
||||
current_block_size = 0
|
||||
|
||||
current_block.append(item)
|
||||
current_block_size += weight_size
|
||||
total_size += weight_size
|
||||
|
||||
# Add the last block
|
||||
sharded_state_dicts.append(current_block)
|
||||
|
||||
# If we only have one shard, we return it
|
||||
if len(sharded_state_dicts) == 1:
|
||||
return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None
|
||||
|
||||
# Otherwise, let's build the index
|
||||
weight_map = {}
|
||||
shards = {}
|
||||
for idx, shard in enumerate(sharded_state_dicts):
|
||||
shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
|
||||
shards[shard_file] = shard
|
||||
for weight in shard:
|
||||
weight_name = weight.name
|
||||
weight_map[weight_name] = shard_file
|
||||
|
||||
# Add the metadata
|
||||
metadata = {"total_size": total_size}
|
||||
index = {"metadata": metadata, "weight_map": weight_map}
|
||||
return shards, index
|
||||
|
||||
|
||||
def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=True):
|
||||
"""
|
||||
This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
|
||||
the TF weights from the shard file accordingly to their names and shapes.
|
||||
|
||||
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
|
||||
loaded in the model.
|
||||
|
||||
Args:
|
||||
model (`tf.keras.models.Model`): The model in which to load the checkpoint.
|
||||
shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
|
||||
ignore_mismatched_sizes`bool`, *optional`, defaults to `True`):
|
||||
Whether or not to ignore the mismatch between the sizes
|
||||
strict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
|
||||
|
||||
Returns:
|
||||
Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
|
||||
mismatched layers.
|
||||
"""
|
||||
|
||||
# Load the index
|
||||
missing_keys = []
|
||||
unexpected_keys = set()
|
||||
saved_keys = set()
|
||||
missmatched_keys = set()
|
||||
|
||||
# Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
|
||||
# the weight, we have to get rid of the first prefix of the name of the layer.
|
||||
model_keys = set("/".join(k.name.split("/")[1:]) for k in model.weights)
|
||||
model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)}
|
||||
|
||||
for shard_file in shard_files:
|
||||
state_dict = tf.io.read_file(shard_file)
|
||||
saved_weight_names_set, unexpected_keys_set, missmatched_keys_set = load_tf_shard(
|
||||
model, model_layer_map, shard_file, ignore_mismatched_sizes=ignore_mismatched_sizes
|
||||
)
|
||||
saved_keys.update(saved_weight_names_set)
|
||||
unexpected_keys.update(unexpected_keys_set)
|
||||
missmatched_keys.update(missmatched_keys_set)
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
missing_keys = model_keys - saved_keys
|
||||
if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
|
||||
error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
|
||||
if len(missing_keys) > 0:
|
||||
str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
|
||||
error_message += f"\nMissing key(s): {str_missing_keys}."
|
||||
if len(unexpected_keys) > 0:
|
||||
str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
|
||||
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
return missing_keys, unexpected_keys, missmatched_keys
|
||||
|
||||
|
||||
def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False):
|
||||
"""
|
||||
Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.
|
||||
|
||||
Args:
|
||||
model (`tf.keras.models.Model`): Model in which the weights are loaded
|
||||
model_layer_map (`Dict`): A dictionnary mapping the layer name to the index of the layer in the model.
|
||||
resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded
|
||||
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys
|
||||
|
||||
Returns:
|
||||
`tf.keras.models.Model`: Three lists, one for the layers that were found and succesfully restored (from the
|
||||
shard file), one for the missmatched layers, and another one for the unexpected layers.
|
||||
"""
|
||||
saved_weight_names_set = set()
|
||||
saved_weights = {}
|
||||
missmatched_keys = set()
|
||||
unexpected_keys = set()
|
||||
# Read the H5 file
|
||||
try:
|
||||
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
|
||||
# Retrieve the name of each layer from the H5 file
|
||||
saved_h5_model_layers_name = set(
|
||||
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
|
||||
)
|
||||
weight_value_tuples = []
|
||||
|
||||
# Compute missing and unexpected sub layers
|
||||
# Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
|
||||
for layer_name in saved_h5_model_layers_name:
|
||||
h5_layer_object = sharded_checkpoint_file[layer_name]
|
||||
saved_weights[layer_name] = np.asarray(h5_layer_object)
|
||||
|
||||
saved_weight_names_set.add(layer_name)
|
||||
|
||||
if layer_name not in model_layer_map:
|
||||
unexpected_keys.add(layer_name)
|
||||
else:
|
||||
symbolic_weight = model.weights[model_layer_map[layer_name]]
|
||||
|
||||
saved_weight_value = saved_weights[layer_name]
|
||||
# If the current weight is found
|
||||
if saved_weight_value is not None:
|
||||
# Check if the shape of the current weight and the one from the H5 file are different
|
||||
if K.int_shape(symbolic_weight) != saved_weight_value.shape:
|
||||
# If yes we reshape the weight from the H5 file accordingly to the current weight
|
||||
# If the two shapes are not compatible we raise an issue
|
||||
try:
|
||||
array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
|
||||
except ValueError as e:
|
||||
if ignore_mismatched_sizes:
|
||||
missmatched_keys.add(
|
||||
(layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
array = saved_weight_value
|
||||
|
||||
# We create the tuple that will be loaded and add it to the final list
|
||||
weight_value_tuples.append((symbolic_weight, array))
|
||||
|
||||
K.batch_set_value(weight_value_tuples)
|
||||
|
||||
return saved_weight_names_set, unexpected_keys, missmatched_keys
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
with open(resolved_archive_file) as f:
|
||||
if f.read().startswith("version"):
|
||||
raise OSError(
|
||||
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
||||
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
||||
"you cloned."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained"
|
||||
" model. Make sure you have saved the model properly."
|
||||
) from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' "
|
||||
f"at '{resolved_archive_file}'. "
|
||||
"If you tried to load a TF model from a sharded checkpoint, you should try converting the model"
|
||||
"by loading it in pytorch and saving it localy. A convertion script should be realeased soon."
|
||||
)
|
||||
|
||||
|
||||
def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
|
||||
"""
|
||||
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
|
||||
Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
|
||||
shapes.
|
||||
|
||||
Args:
|
||||
model (`tf.keras.models.Model`):
|
||||
@@ -575,9 +814,11 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
|
||||
mismatched_layers = []
|
||||
|
||||
# Read the H5 file
|
||||
with h5py.File(resolved_archive_file, "r") as f:
|
||||
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
|
||||
# Retrieve the name of each layer from the H5 file
|
||||
saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
|
||||
saved_h5_model_layers_name = set(
|
||||
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
|
||||
)
|
||||
|
||||
# Find the missing layers from the high level list of layers
|
||||
missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)
|
||||
@@ -594,7 +835,7 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
|
||||
# if layer_name from the H5 file belongs to the layers from the instantiated model
|
||||
if layer.name in saved_h5_model_layers_name:
|
||||
# Get the H5 layer object from its name
|
||||
h5_layer_object = f[layer.name]
|
||||
h5_layer_object = sharded_checkpoint_file[layer.name]
|
||||
# Get all the weights as a list from the layer object
|
||||
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
|
||||
saved_weights = {}
|
||||
@@ -1624,7 +1865,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save_pretrained(self, save_directory, saved_model=False, version=1, push_to_hub=False, **kwargs):
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory,
|
||||
saved_model=False,
|
||||
version=1,
|
||||
push_to_hub=False,
|
||||
max_shard_size: Union[int, str] = "10GB",
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
[`~TFPreTrainedModel.from_pretrained`] class method.
|
||||
@@ -1649,6 +1898,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
|
||||
</Tip>
|
||||
|
||||
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
||||
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
|
||||
which will be bigger than `max_shard_size`.
|
||||
|
||||
</Tip>
|
||||
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
@@ -1679,8 +1939,48 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
|
||||
|
||||
shards, index = tf_shard_checkpoint(self.weights, max_shard_size)
|
||||
|
||||
# Clean the folder from a previous save
|
||||
for filename in os.listdir(save_directory):
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
||||
# in distributed settings to avoid race conditions.
|
||||
if (
|
||||
filename.startswith(TF2_WEIGHTS_NAME[:-4])
|
||||
and os.path.isfile(full_filename)
|
||||
and filename not in shards.keys()
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
if index is None:
|
||||
self.save_weights(output_model_file)
|
||||
logger.info(f"Model weights saved in {output_model_file}")
|
||||
else:
|
||||
save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as index_file:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
index_file.write(content)
|
||||
logger.info(
|
||||
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
for shard_file, shard in shards.items():
|
||||
with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
|
||||
save_attributes_to_hdf5_group(
|
||||
shard_file,
|
||||
"layer_names",
|
||||
["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard],
|
||||
)
|
||||
|
||||
for layer in sorted(shard, key=lambda x: x.name):
|
||||
param_dset = shard_file.create_dataset(
|
||||
"/".join(layer.name.split("/")[1:]), layer.numpy().shape, dtype=layer.numpy().dtype
|
||||
)
|
||||
param_dset[:] = layer.numpy()
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
@@ -1844,6 +2144,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
else:
|
||||
model_kwargs = kwargs
|
||||
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# index of the files.
|
||||
is_sharded = False
|
||||
sharded_metadata = None
|
||||
# Load model
|
||||
if pretrained_model_name_or_path is not None:
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
@@ -1853,6 +2157,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||
# Load from a TF 2.0 checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
|
||||
# Load from a sharded PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||
raise EnvironmentError(
|
||||
@@ -1906,6 +2214,28 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
if filename == TF2_WEIGHTS_NAME:
|
||||
try:
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=TF2_WEIGHTS_INDEX_NAME,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
)
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
is_sharded = True
|
||||
except EntryNotFoundError:
|
||||
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||
# message.
|
||||
has_file_kwargs = {
|
||||
"revision": revision,
|
||||
"mirror": mirror,
|
||||
@@ -1914,14 +2244,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
}
|
||||
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
|
||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
|
||||
"those weights."
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
||||
" load this model from those weights."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
|
||||
f"or {WEIGHTS_NAME}."
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
@@ -1955,6 +2285,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
else:
|
||||
resolved_archive_file = None
|
||||
|
||||
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
||||
if is_sharded:
|
||||
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
||||
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
|
||||
pretrained_model_name_or_path,
|
||||
resolved_archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
)
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
# composed models, *e.g.* TFRag, require special treatment when it comes to loading
|
||||
@@ -1978,10 +2325,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
else:
|
||||
model(model.dummy_inputs) # build the network with dummy inputs
|
||||
|
||||
assert os.path.isfile(resolved_archive_file), f"Error retrieving file {resolved_archive_file}"
|
||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||
try:
|
||||
if is_sharded:
|
||||
for file in resolved_archive_file:
|
||||
os.path.isfile(file), f"Error retrieving files {file}"
|
||||
|
||||
missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
|
||||
model,
|
||||
resolved_archive_file,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
else:
|
||||
missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
|
||||
model,
|
||||
resolved_archive_file,
|
||||
|
||||
@@ -148,6 +148,7 @@ from .import_utils import (
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
|
||||
TF_WEIGHTS_NAME = "model.ckpt"
|
||||
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
||||
CONFIG_NAME = "config.json"
|
||||
|
||||
@@ -861,6 +861,7 @@ class PushToHubMixin:
|
||||
organization: Optional[str] = None,
|
||||
private: Optional[bool] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
max_shard_size: Optional[Union[int, str]] = "10GB",
|
||||
**model_card_kwargs
|
||||
) -> str:
|
||||
"""
|
||||
@@ -936,8 +937,9 @@ class PushToHubMixin:
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
# Save the files in the cloned repo
|
||||
self.save_pretrained(repo_path_or_name)
|
||||
|
||||
if hasattr(self, "history") and hasattr(self, "create_model_card"):
|
||||
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
|
||||
# This is a Keras model and we might be able to fish out its History and make a model card out of it
|
||||
base_model_card_args = {
|
||||
"output_dir": repo_path_or_name,
|
||||
@@ -945,6 +947,9 @@ class PushToHubMixin:
|
||||
}
|
||||
base_model_card_args.update(model_card_kwargs)
|
||||
self.create_model_card(**base_model_card_args)
|
||||
else:
|
||||
# FLAX does not support sharding yet, will come in next PR
|
||||
self.save_pretrained(repo_path_or_name)
|
||||
# Commit and push!
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
|
||||
@@ -1075,3 +1080,114 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
|
||||
except Exception:
|
||||
# We don't want to error in case of connection errors of any kind.
|
||||
pass
|
||||
|
||||
|
||||
def convert_file_size_to_int(size: Union[int, str]):
|
||||
"""
|
||||
Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
|
||||
|
||||
Args:
|
||||
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> convert_file_size_to_int("1MiB")
|
||||
1048576
|
||||
```
|
||||
"""
|
||||
if isinstance(size, int):
|
||||
return size
|
||||
if size.upper().endswith("GIB"):
|
||||
return int(size[:-3]) * (2**30)
|
||||
if size.upper().endswith("MIB"):
|
||||
return int(size[:-3]) * (2**20)
|
||||
if size.upper().endswith("KIB"):
|
||||
return int(size[:-3]) * (2**10)
|
||||
if size.upper().endswith("GB"):
|
||||
int_size = int(size[:-2]) * (10**9)
|
||||
return int_size // 8 if size.endswith("b") else int_size
|
||||
if size.upper().endswith("MB"):
|
||||
int_size = int(size[:-2]) * (10**6)
|
||||
return int_size // 8 if size.endswith("b") else int_size
|
||||
if size.upper().endswith("KB"):
|
||||
int_size = int(size[:-2]) * (10**3)
|
||||
return int_size // 8 if size.endswith("b") else int_size
|
||||
raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
|
||||
|
||||
|
||||
def get_checkpoint_shard_files(
|
||||
pretrained_model_name_or_path,
|
||||
index_filename,
|
||||
cache_dir=None,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
resume_download=False,
|
||||
local_files_only=False,
|
||||
use_auth_token=None,
|
||||
user_agent=None,
|
||||
revision=None,
|
||||
mirror=None,
|
||||
):
|
||||
"""
|
||||
For a given model:
|
||||
|
||||
- download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
|
||||
Hub
|
||||
- returns the list of paths to all the shards, as well as some metadata.
|
||||
|
||||
For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
|
||||
index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
|
||||
"""
|
||||
import json
|
||||
|
||||
if not os.path.isfile(index_filename):
|
||||
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
|
||||
|
||||
with open(index_filename, "r") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
shard_filenames = sorted(list(set(index["weight_map"].values())))
|
||||
sharded_metadata = index["metadata"]
|
||||
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
|
||||
|
||||
# First, let's deal with local folder.
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
|
||||
return shard_filenames, sharded_metadata
|
||||
|
||||
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
|
||||
cached_filenames = []
|
||||
for shard_filename in shard_filenames:
|
||||
shard_url = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL
|
||||
cached_filename = cached_path(
|
||||
shard_url,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here.
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
|
||||
"required according to the checkpoint index."
|
||||
)
|
||||
except HTTPError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
|
||||
" again after checking your internet connection."
|
||||
)
|
||||
|
||||
cached_filenames.append(cached_filename)
|
||||
|
||||
return cached_filenames, sharded_metadata
|
||||
|
||||
@@ -53,6 +53,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import h5py
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -85,7 +86,12 @@ if is_tf_available():
|
||||
TFSampleDecoderOnlyOutput,
|
||||
TFSampleEncoderDecoderOutput,
|
||||
)
|
||||
from transformers.modeling_tf_utils import unpack_inputs
|
||||
from transformers.modeling_tf_utils import (
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
tf_shard_checkpoint,
|
||||
unpack_inputs,
|
||||
)
|
||||
from transformers.tf_utils import stable_softmax
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
@@ -1867,6 +1873,129 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
out = masked_softmax(x, boolean_mask)
|
||||
assert tf.experimental.numpy.allclose(xla_out, out)
|
||||
|
||||
def test_checkpoint_sharding_from_hub(self):
|
||||
model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
# the model above is the same as the model below, just a sharded version.
|
||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
def test_shard_checkpoint(self):
|
||||
# This is the model we will use, total size 340,000 bytes.
|
||||
model = tf.keras.Sequential(
|
||||
[
|
||||
tf.keras.layers.Dense(200, use_bias=False), # size 80,000
|
||||
tf.keras.layers.Dense(200, use_bias=False), # size 160,000
|
||||
tf.keras.layers.Dense(100, use_bias=False), # size 80,000
|
||||
tf.keras.layers.Dense(50, use_bias=False), # size 20,000
|
||||
]
|
||||
)
|
||||
inputs = tf.zeros((1, 100), dtype=tf.float32)
|
||||
model(inputs)
|
||||
weights = model.weights
|
||||
weights_dict = {w.name: w for w in weights}
|
||||
with self.subTest("No shard when max size is bigger than model size"):
|
||||
shards, index = tf_shard_checkpoint(weights)
|
||||
self.assertIsNone(index)
|
||||
self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
|
||||
|
||||
with self.subTest("Test sharding, no weights bigger than max size"):
|
||||
shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
|
||||
# Split is first two layers then last two.
|
||||
self.assertDictEqual(
|
||||
index,
|
||||
{
|
||||
"metadata": {"total_size": 340000},
|
||||
"weight_map": {
|
||||
"dense/kernel:0": "tf_model-00001-of-00002.h5",
|
||||
"dense_1/kernel:0": "tf_model-00001-of-00002.h5",
|
||||
"dense_2/kernel:0": "tf_model-00002-of-00002.h5",
|
||||
"dense_3/kernel:0": "tf_model-00002-of-00002.h5",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
|
||||
shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||
self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
|
||||
|
||||
with self.subTest("Test sharding with weights bigger than max size"):
|
||||
shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
|
||||
# Split is first layer, second layer then last 2.
|
||||
self.assertDictEqual(
|
||||
index,
|
||||
{
|
||||
"metadata": {"total_size": 340000},
|
||||
"weight_map": {
|
||||
"dense/kernel:0": "tf_model-00001-of-00003.h5",
|
||||
"dense_1/kernel:0": "tf_model-00002-of-00003.h5",
|
||||
"dense_2/kernel:0": "tf_model-00003-of-00003.h5",
|
||||
"dense_3/kernel:0": "tf_model-00003-of-00003.h5",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
shard1 = [weights_dict["dense/kernel:0"]]
|
||||
shard2 = [weights_dict["dense_1/kernel:0"]]
|
||||
shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||
self.assertDictEqual(
|
||||
shards,
|
||||
{
|
||||
"tf_model-00001-of-00003.h5": shard1,
|
||||
"tf_model-00002-of-00003.h5": shard2,
|
||||
"tf_model-00003-of-00003.h5": shard3,
|
||||
},
|
||||
)
|
||||
|
||||
def test_checkpoint_sharding_local(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||
|
||||
# Get each shard file and its size
|
||||
shard_to_size = {}
|
||||
for shard in os.listdir(tmp_dir):
|
||||
if shard.endswith(".h5"):
|
||||
shard_file = os.path.join(tmp_dir, shard)
|
||||
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||
|
||||
index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
|
||||
# Check there is an index but no regular weight file
|
||||
self.assertTrue(os.path.isfile(index_file))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
|
||||
# Check a file is bigger than max_size only when it has a single weight
|
||||
for shard_file, size in shard_to_size.items():
|
||||
if max_size.endswith("kiB"):
|
||||
max_size_int = int(max_size[:-3]) * 2**10
|
||||
else:
|
||||
max_size_int = int(max_size[:-2]) * 10**3
|
||||
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||
# the size asked for (since we count parameters)
|
||||
if size >= max_size_int + 50000:
|
||||
with h5py.File(shard_file, "r") as state_file:
|
||||
self.assertEqual(len(state_file), 1)
|
||||
|
||||
# Check the index and the shard files found match
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".h5"))
|
||||
self.assertSetEqual(all_shards, shards_found)
|
||||
|
||||
# Finally, check the model can be reloaded
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
model(model.dummy_inputs)
|
||||
new_model(model.dummy_inputs)
|
||||
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user