TF Sharded (#17713)
* initial commit * update modeeling tf utils * quality * clean and update args * update * remove potential bug * code quality * update * update max shard * update tests for sharding from pretrained * fix remaining test * make style * h5py if tf available * update and fix test * fix test * style * modified push to hub to support shard for TF * quick fix * update code * merge branch main and style * Apply suggestions from code review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * update based on reviews * update doc * update and style * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update based on reviews * fix typo * style Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,9 @@
|
|||||||
"""TF general model utils."""
|
"""TF general model utils."""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import gc
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
@@ -33,7 +35,9 @@ from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
|||||||
from tensorflow.python.keras.saving import hdf5_format
|
from tensorflow.python.keras.saving import hdf5_format
|
||||||
|
|
||||||
from huggingface_hub import Repository, list_repo_files
|
from huggingface_hub import Repository, list_repo_files
|
||||||
|
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
|
||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
|
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||||
|
|
||||||
from . import DataCollatorWithPadding, DefaultDataCollator
|
from . import DataCollatorWithPadding, DefaultDataCollator
|
||||||
from .activations_tf import get_tf_activation
|
from .activations_tf import get_tf_activation
|
||||||
@@ -44,6 +48,7 @@ from .tf_utils import shape_list
|
|||||||
from .utils import (
|
from .utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||||
|
TF2_WEIGHTS_INDEX_NAME,
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
EntryNotFoundError,
|
EntryNotFoundError,
|
||||||
@@ -554,9 +559,243 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def dtype_byte_size(dtype):
|
||||||
|
"""
|
||||||
|
Returns the size (in bytes) occupied by one parameter of type `dtype`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
>>> dtype_byte_size(tf.float32)
|
||||||
|
4
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if dtype == tf.bool:
|
||||||
|
return 1 / 8
|
||||||
|
bit_search = re.search("[^\d](\d+)$", dtype.name)
|
||||||
|
if bit_search is None:
|
||||||
|
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||||
|
bit_size = int(bit_search.groups()[0])
|
||||||
|
return bit_size // 8
|
||||||
|
|
||||||
|
|
||||||
|
def tf_shard_checkpoint(weights, max_shard_size="10GB"):
|
||||||
|
"""
|
||||||
|
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||||
|
given size.
|
||||||
|
|
||||||
|
The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
|
||||||
|
optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
|
||||||
|
limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
|
||||||
|
[6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will
|
||||||
|
have a size greater than `max_shard_size`.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights (`Dict[str, tf.RessourceVariable]`): The list of tf.RessourceVariable of a model to save.
|
||||||
|
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||||
|
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
|
||||||
|
(like `"5MB"`).
|
||||||
|
"""
|
||||||
|
max_shard_size = convert_file_size_to_int(max_shard_size)
|
||||||
|
|
||||||
|
sharded_state_dicts = []
|
||||||
|
current_block = []
|
||||||
|
current_block_size = 0
|
||||||
|
total_size = 0
|
||||||
|
|
||||||
|
for item in weights:
|
||||||
|
weight_size = item.numpy().size * dtype_byte_size(item.dtype)
|
||||||
|
|
||||||
|
# If this weight is going to tip up over the maximal size, we split.
|
||||||
|
if current_block_size + weight_size > max_shard_size:
|
||||||
|
sharded_state_dicts.append(current_block)
|
||||||
|
current_block = []
|
||||||
|
current_block_size = 0
|
||||||
|
|
||||||
|
current_block.append(item)
|
||||||
|
current_block_size += weight_size
|
||||||
|
total_size += weight_size
|
||||||
|
|
||||||
|
# Add the last block
|
||||||
|
sharded_state_dicts.append(current_block)
|
||||||
|
|
||||||
|
# If we only have one shard, we return it
|
||||||
|
if len(sharded_state_dicts) == 1:
|
||||||
|
return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None
|
||||||
|
|
||||||
|
# Otherwise, let's build the index
|
||||||
|
weight_map = {}
|
||||||
|
shards = {}
|
||||||
|
for idx, shard in enumerate(sharded_state_dicts):
|
||||||
|
shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
|
||||||
|
shards[shard_file] = shard
|
||||||
|
for weight in shard:
|
||||||
|
weight_name = weight.name
|
||||||
|
weight_map[weight_name] = shard_file
|
||||||
|
|
||||||
|
# Add the metadata
|
||||||
|
metadata = {"total_size": total_size}
|
||||||
|
index = {"metadata": metadata, "weight_map": weight_map}
|
||||||
|
return shards, index
|
||||||
|
|
||||||
|
|
||||||
|
def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=True):
|
||||||
|
"""
|
||||||
|
This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
|
||||||
|
the TF weights from the shard file accordingly to their names and shapes.
|
||||||
|
|
||||||
|
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
|
||||||
|
loaded in the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (`tf.keras.models.Model`): The model in which to load the checkpoint.
|
||||||
|
shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
|
||||||
|
ignore_mismatched_sizes`bool`, *optional`, defaults to `True`):
|
||||||
|
Whether or not to ignore the mismatch between the sizes
|
||||||
|
strict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
|
||||||
|
mismatched layers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Load the index
|
||||||
|
missing_keys = []
|
||||||
|
unexpected_keys = set()
|
||||||
|
saved_keys = set()
|
||||||
|
missmatched_keys = set()
|
||||||
|
|
||||||
|
# Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
|
||||||
|
# the weight, we have to get rid of the first prefix of the name of the layer.
|
||||||
|
model_keys = set("/".join(k.name.split("/")[1:]) for k in model.weights)
|
||||||
|
model_layer_map = {"/".join(k.name.split("/")[1:]): i for i, k in enumerate(model.weights)}
|
||||||
|
|
||||||
|
for shard_file in shard_files:
|
||||||
|
state_dict = tf.io.read_file(shard_file)
|
||||||
|
saved_weight_names_set, unexpected_keys_set, missmatched_keys_set = load_tf_shard(
|
||||||
|
model, model_layer_map, shard_file, ignore_mismatched_sizes=ignore_mismatched_sizes
|
||||||
|
)
|
||||||
|
saved_keys.update(saved_weight_names_set)
|
||||||
|
unexpected_keys.update(unexpected_keys_set)
|
||||||
|
missmatched_keys.update(missmatched_keys_set)
|
||||||
|
del state_dict
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
missing_keys = model_keys - saved_keys
|
||||||
|
if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
|
||||||
|
error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
|
||||||
|
if len(missing_keys) > 0:
|
||||||
|
str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
|
||||||
|
error_message += f"\nMissing key(s): {str_missing_keys}."
|
||||||
|
if len(unexpected_keys) > 0:
|
||||||
|
str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
|
||||||
|
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
||||||
|
raise RuntimeError(error_message)
|
||||||
|
|
||||||
|
return missing_keys, unexpected_keys, missmatched_keys
|
||||||
|
|
||||||
|
|
||||||
|
def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False):
|
||||||
|
"""
|
||||||
|
Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (`tf.keras.models.Model`): Model in which the weights are loaded
|
||||||
|
model_layer_map (`Dict`): A dictionnary mapping the layer name to the index of the layer in the model.
|
||||||
|
resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded
|
||||||
|
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tf.keras.models.Model`: Three lists, one for the layers that were found and succesfully restored (from the
|
||||||
|
shard file), one for the missmatched layers, and another one for the unexpected layers.
|
||||||
|
"""
|
||||||
|
saved_weight_names_set = set()
|
||||||
|
saved_weights = {}
|
||||||
|
missmatched_keys = set()
|
||||||
|
unexpected_keys = set()
|
||||||
|
# Read the H5 file
|
||||||
|
try:
|
||||||
|
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
|
||||||
|
# Retrieve the name of each layer from the H5 file
|
||||||
|
saved_h5_model_layers_name = set(
|
||||||
|
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
|
||||||
|
)
|
||||||
|
weight_value_tuples = []
|
||||||
|
|
||||||
|
# Compute missing and unexpected sub layers
|
||||||
|
# Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
|
||||||
|
for layer_name in saved_h5_model_layers_name:
|
||||||
|
h5_layer_object = sharded_checkpoint_file[layer_name]
|
||||||
|
saved_weights[layer_name] = np.asarray(h5_layer_object)
|
||||||
|
|
||||||
|
saved_weight_names_set.add(layer_name)
|
||||||
|
|
||||||
|
if layer_name not in model_layer_map:
|
||||||
|
unexpected_keys.add(layer_name)
|
||||||
|
else:
|
||||||
|
symbolic_weight = model.weights[model_layer_map[layer_name]]
|
||||||
|
|
||||||
|
saved_weight_value = saved_weights[layer_name]
|
||||||
|
# If the current weight is found
|
||||||
|
if saved_weight_value is not None:
|
||||||
|
# Check if the shape of the current weight and the one from the H5 file are different
|
||||||
|
if K.int_shape(symbolic_weight) != saved_weight_value.shape:
|
||||||
|
# If yes we reshape the weight from the H5 file accordingly to the current weight
|
||||||
|
# If the two shapes are not compatible we raise an issue
|
||||||
|
try:
|
||||||
|
array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
|
||||||
|
except ValueError as e:
|
||||||
|
if ignore_mismatched_sizes:
|
||||||
|
missmatched_keys.add(
|
||||||
|
(layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
array = saved_weight_value
|
||||||
|
|
||||||
|
# We create the tuple that will be loaded and add it to the final list
|
||||||
|
weight_value_tuples.append((symbolic_weight, array))
|
||||||
|
|
||||||
|
K.batch_set_value(weight_value_tuples)
|
||||||
|
|
||||||
|
return saved_weight_names_set, unexpected_keys, missmatched_keys
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
try:
|
||||||
|
with open(resolved_archive_file) as f:
|
||||||
|
if f.read().startswith("version"):
|
||||||
|
raise OSError(
|
||||||
|
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
||||||
|
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
||||||
|
"you cloned."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained"
|
||||||
|
" model. Make sure you have saved the model properly."
|
||||||
|
) from e
|
||||||
|
except (UnicodeDecodeError, ValueError):
|
||||||
|
raise OSError(
|
||||||
|
f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' "
|
||||||
|
f"at '{resolved_archive_file}'. "
|
||||||
|
"If you tried to load a TF model from a sharded checkpoint, you should try converting the model"
|
||||||
|
"by loading it in pytorch and saving it localy. A convertion script should be realeased soon."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
|
def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
|
||||||
"""
|
"""
|
||||||
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
|
Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
|
||||||
|
shapes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (`tf.keras.models.Model`):
|
model (`tf.keras.models.Model`):
|
||||||
@@ -575,9 +814,11 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
|
|||||||
mismatched_layers = []
|
mismatched_layers = []
|
||||||
|
|
||||||
# Read the H5 file
|
# Read the H5 file
|
||||||
with h5py.File(resolved_archive_file, "r") as f:
|
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
|
||||||
# Retrieve the name of each layer from the H5 file
|
# Retrieve the name of each layer from the H5 file
|
||||||
saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
|
saved_h5_model_layers_name = set(
|
||||||
|
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
|
||||||
|
)
|
||||||
|
|
||||||
# Find the missing layers from the high level list of layers
|
# Find the missing layers from the high level list of layers
|
||||||
missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)
|
missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)
|
||||||
@@ -594,7 +835,7 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
|
|||||||
# if layer_name from the H5 file belongs to the layers from the instantiated model
|
# if layer_name from the H5 file belongs to the layers from the instantiated model
|
||||||
if layer.name in saved_h5_model_layers_name:
|
if layer.name in saved_h5_model_layers_name:
|
||||||
# Get the H5 layer object from its name
|
# Get the H5 layer object from its name
|
||||||
h5_layer_object = f[layer.name]
|
h5_layer_object = sharded_checkpoint_file[layer.name]
|
||||||
# Get all the weights as a list from the layer object
|
# Get all the weights as a list from the layer object
|
||||||
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
|
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
|
||||||
saved_weights = {}
|
saved_weights = {}
|
||||||
@@ -1624,7 +1865,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def save_pretrained(self, save_directory, saved_model=False, version=1, push_to_hub=False, **kwargs):
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory,
|
||||||
|
saved_model=False,
|
||||||
|
version=1,
|
||||||
|
push_to_hub=False,
|
||||||
|
max_shard_size: Union[int, str] = "10GB",
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||||
[`~TFPreTrainedModel.from_pretrained`] class method.
|
[`~TFPreTrainedModel.from_pretrained`] class method.
|
||||||
@@ -1649,6 +1898,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
|
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||||
|
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
||||||
|
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
|
||||||
|
which will be bigger than `max_shard_size`.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
kwargs:
|
kwargs:
|
||||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
"""
|
"""
|
||||||
@@ -1679,8 +1939,48 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
|
|
||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# If we save using the predefined names, we can load using `from_pretrained`
|
||||||
output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
|
output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
|
||||||
|
|
||||||
|
shards, index = tf_shard_checkpoint(self.weights, max_shard_size)
|
||||||
|
|
||||||
|
# Clean the folder from a previous save
|
||||||
|
for filename in os.listdir(save_directory):
|
||||||
|
full_filename = os.path.join(save_directory, filename)
|
||||||
|
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
||||||
|
# in distributed settings to avoid race conditions.
|
||||||
|
if (
|
||||||
|
filename.startswith(TF2_WEIGHTS_NAME[:-4])
|
||||||
|
and os.path.isfile(full_filename)
|
||||||
|
and filename not in shards.keys()
|
||||||
|
):
|
||||||
|
os.remove(full_filename)
|
||||||
|
|
||||||
|
if index is None:
|
||||||
self.save_weights(output_model_file)
|
self.save_weights(output_model_file)
|
||||||
logger.info(f"Model weights saved in {output_model_file}")
|
logger.info(f"Model weights saved in {output_model_file}")
|
||||||
|
else:
|
||||||
|
save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
|
||||||
|
# Save the index as well
|
||||||
|
with open(save_index_file, "w", encoding="utf-8") as index_file:
|
||||||
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||||
|
index_file.write(content)
|
||||||
|
logger.info(
|
||||||
|
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||||
|
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||||
|
f"index located at {save_index_file}."
|
||||||
|
)
|
||||||
|
for shard_file, shard in shards.items():
|
||||||
|
with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
|
||||||
|
save_attributes_to_hdf5_group(
|
||||||
|
shard_file,
|
||||||
|
"layer_names",
|
||||||
|
["/".join(layer.name.split("/")[1:]).encode("utf8") for layer in shard],
|
||||||
|
)
|
||||||
|
|
||||||
|
for layer in sorted(shard, key=lambda x: x.name):
|
||||||
|
param_dset = shard_file.create_dataset(
|
||||||
|
"/".join(layer.name.split("/")[1:]), layer.numpy().shape, dtype=layer.numpy().dtype
|
||||||
|
)
|
||||||
|
param_dset[:] = layer.numpy()
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||||
@@ -1844,6 +2144,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
else:
|
else:
|
||||||
model_kwargs = kwargs
|
model_kwargs = kwargs
|
||||||
|
|
||||||
|
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||||
|
# index of the files.
|
||||||
|
is_sharded = False
|
||||||
|
sharded_metadata = None
|
||||||
# Load model
|
# Load model
|
||||||
if pretrained_model_name_or_path is not None:
|
if pretrained_model_name_or_path is not None:
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
@@ -1853,6 +2157,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||||
# Load from a TF 2.0 checkpoint
|
# Load from a TF 2.0 checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||||
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
|
||||||
|
# Load from a sharded PyTorch checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
|
||||||
|
is_sharded = True
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
@@ -1906,6 +2214,28 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
)
|
)
|
||||||
except EntryNotFoundError:
|
except EntryNotFoundError:
|
||||||
if filename == TF2_WEIGHTS_NAME:
|
if filename == TF2_WEIGHTS_NAME:
|
||||||
|
try:
|
||||||
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
|
archive_file = hf_bucket_url(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
filename=TF2_WEIGHTS_INDEX_NAME,
|
||||||
|
revision=revision,
|
||||||
|
mirror=mirror,
|
||||||
|
)
|
||||||
|
resolved_archive_file = cached_path(
|
||||||
|
archive_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
is_sharded = True
|
||||||
|
except EntryNotFoundError:
|
||||||
|
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||||
|
# message.
|
||||||
has_file_kwargs = {
|
has_file_kwargs = {
|
||||||
"revision": revision,
|
"revision": revision,
|
||||||
"mirror": mirror,
|
"mirror": mirror,
|
||||||
@@ -1914,14 +2244,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
}
|
}
|
||||||
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
|
f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
||||||
"those weights."
|
" load this model from those weights."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f"or {WEIGHTS_NAME}."
|
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
@@ -1955,6 +2285,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
else:
|
else:
|
||||||
resolved_archive_file = None
|
resolved_archive_file = None
|
||||||
|
|
||||||
|
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
||||||
|
if is_sharded:
|
||||||
|
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
||||||
|
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
resolved_archive_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
revision=revision,
|
||||||
|
mirror=mirror,
|
||||||
|
)
|
||||||
|
|
||||||
config.name_or_path = pretrained_model_name_or_path
|
config.name_or_path = pretrained_model_name_or_path
|
||||||
|
|
||||||
# composed models, *e.g.* TFRag, require special treatment when it comes to loading
|
# composed models, *e.g.* TFRag, require special treatment when it comes to loading
|
||||||
@@ -1978,10 +2325,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
else:
|
else:
|
||||||
model(model.dummy_inputs) # build the network with dummy inputs
|
model(model.dummy_inputs) # build the network with dummy inputs
|
||||||
|
|
||||||
assert os.path.isfile(resolved_archive_file), f"Error retrieving file {resolved_archive_file}"
|
|
||||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||||
try:
|
try:
|
||||||
|
if is_sharded:
|
||||||
|
for file in resolved_archive_file:
|
||||||
|
os.path.isfile(file), f"Error retrieving files {file}"
|
||||||
|
|
||||||
|
missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
|
||||||
|
model,
|
||||||
|
resolved_archive_file,
|
||||||
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
|
missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
|
||||||
model,
|
model,
|
||||||
resolved_archive_file,
|
resolved_archive_file,
|
||||||
|
|||||||
@@ -148,6 +148,7 @@ from .import_utils import (
|
|||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||||
|
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
|
||||||
TF_WEIGHTS_NAME = "model.ckpt"
|
TF_WEIGHTS_NAME = "model.ckpt"
|
||||||
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
|
|||||||
@@ -861,6 +861,7 @@ class PushToHubMixin:
|
|||||||
organization: Optional[str] = None,
|
organization: Optional[str] = None,
|
||||||
private: Optional[bool] = None,
|
private: Optional[bool] = None,
|
||||||
use_auth_token: Optional[Union[bool, str]] = None,
|
use_auth_token: Optional[Union[bool, str]] = None,
|
||||||
|
max_shard_size: Optional[Union[int, str]] = "10GB",
|
||||||
**model_card_kwargs
|
**model_card_kwargs
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -936,8 +937,9 @@ class PushToHubMixin:
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
# Save the files in the cloned repo
|
# Save the files in the cloned repo
|
||||||
self.save_pretrained(repo_path_or_name)
|
|
||||||
if hasattr(self, "history") and hasattr(self, "create_model_card"):
|
if hasattr(self, "history") and hasattr(self, "create_model_card"):
|
||||||
|
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
|
||||||
# This is a Keras model and we might be able to fish out its History and make a model card out of it
|
# This is a Keras model and we might be able to fish out its History and make a model card out of it
|
||||||
base_model_card_args = {
|
base_model_card_args = {
|
||||||
"output_dir": repo_path_or_name,
|
"output_dir": repo_path_or_name,
|
||||||
@@ -945,6 +947,9 @@ class PushToHubMixin:
|
|||||||
}
|
}
|
||||||
base_model_card_args.update(model_card_kwargs)
|
base_model_card_args.update(model_card_kwargs)
|
||||||
self.create_model_card(**base_model_card_args)
|
self.create_model_card(**base_model_card_args)
|
||||||
|
else:
|
||||||
|
# FLAX does not support sharding yet, will come in next PR
|
||||||
|
self.save_pretrained(repo_path_or_name)
|
||||||
# Commit and push!
|
# Commit and push!
|
||||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||||
|
|
||||||
@@ -1075,3 +1080,114 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
|
|||||||
except Exception:
|
except Exception:
|
||||||
# We don't want to error in case of connection errors of any kind.
|
# We don't want to error in case of connection errors of any kind.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def convert_file_size_to_int(size: Union[int, str]):
|
||||||
|
"""
|
||||||
|
Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
>>> convert_file_size_to_int("1MiB")
|
||||||
|
1048576
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if isinstance(size, int):
|
||||||
|
return size
|
||||||
|
if size.upper().endswith("GIB"):
|
||||||
|
return int(size[:-3]) * (2**30)
|
||||||
|
if size.upper().endswith("MIB"):
|
||||||
|
return int(size[:-3]) * (2**20)
|
||||||
|
if size.upper().endswith("KIB"):
|
||||||
|
return int(size[:-3]) * (2**10)
|
||||||
|
if size.upper().endswith("GB"):
|
||||||
|
int_size = int(size[:-2]) * (10**9)
|
||||||
|
return int_size // 8 if size.endswith("b") else int_size
|
||||||
|
if size.upper().endswith("MB"):
|
||||||
|
int_size = int(size[:-2]) * (10**6)
|
||||||
|
return int_size // 8 if size.endswith("b") else int_size
|
||||||
|
if size.upper().endswith("KB"):
|
||||||
|
int_size = int(size[:-2]) * (10**3)
|
||||||
|
return int_size // 8 if size.endswith("b") else int_size
|
||||||
|
raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpoint_shard_files(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
index_filename,
|
||||||
|
cache_dir=None,
|
||||||
|
force_download=False,
|
||||||
|
proxies=None,
|
||||||
|
resume_download=False,
|
||||||
|
local_files_only=False,
|
||||||
|
use_auth_token=None,
|
||||||
|
user_agent=None,
|
||||||
|
revision=None,
|
||||||
|
mirror=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
For a given model:
|
||||||
|
|
||||||
|
- download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
|
||||||
|
Hub
|
||||||
|
- returns the list of paths to all the shards, as well as some metadata.
|
||||||
|
|
||||||
|
For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
|
||||||
|
index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
if not os.path.isfile(index_filename):
|
||||||
|
raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
|
||||||
|
|
||||||
|
with open(index_filename, "r") as f:
|
||||||
|
index = json.loads(f.read())
|
||||||
|
|
||||||
|
shard_filenames = sorted(list(set(index["weight_map"].values())))
|
||||||
|
sharded_metadata = index["metadata"]
|
||||||
|
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
|
||||||
|
|
||||||
|
# First, let's deal with local folder.
|
||||||
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
|
shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
|
||||||
|
return shard_filenames, sharded_metadata
|
||||||
|
|
||||||
|
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
|
||||||
|
cached_filenames = []
|
||||||
|
for shard_filename in shard_filenames:
|
||||||
|
shard_url = hf_bucket_url(
|
||||||
|
pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load from URL
|
||||||
|
cached_filename = cached_path(
|
||||||
|
shard_url,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||||
|
# we don't have to catch them here.
|
||||||
|
except EntryNotFoundError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
|
||||||
|
"required according to the checkpoint index."
|
||||||
|
)
|
||||||
|
except HTTPError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
|
||||||
|
" again after checking your internet connection."
|
||||||
|
)
|
||||||
|
|
||||||
|
cached_filenames.append(cached_filename)
|
||||||
|
|
||||||
|
return cached_filenames, sharded_metadata
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@@ -85,7 +86,12 @@ if is_tf_available():
|
|||||||
TFSampleDecoderOnlyOutput,
|
TFSampleDecoderOnlyOutput,
|
||||||
TFSampleEncoderDecoderOutput,
|
TFSampleEncoderDecoderOutput,
|
||||||
)
|
)
|
||||||
from transformers.modeling_tf_utils import unpack_inputs
|
from transformers.modeling_tf_utils import (
|
||||||
|
TF2_WEIGHTS_INDEX_NAME,
|
||||||
|
TF2_WEIGHTS_NAME,
|
||||||
|
tf_shard_checkpoint,
|
||||||
|
unpack_inputs,
|
||||||
|
)
|
||||||
from transformers.tf_utils import stable_softmax
|
from transformers.tf_utils import stable_softmax
|
||||||
|
|
||||||
if _tf_gpu_memory_limit is not None:
|
if _tf_gpu_memory_limit is not None:
|
||||||
@@ -1867,6 +1873,129 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
out = masked_softmax(x, boolean_mask)
|
out = masked_softmax(x, boolean_mask)
|
||||||
assert tf.experimental.numpy.allclose(xla_out, out)
|
assert tf.experimental.numpy.allclose(xla_out, out)
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_from_hub(self):
|
||||||
|
model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||||
|
# the model above is the same as the model below, just a sharded version.
|
||||||
|
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||||
|
assert np.allclose(p1.numpy(), p2.numpy())
|
||||||
|
|
||||||
|
def test_shard_checkpoint(self):
|
||||||
|
# This is the model we will use, total size 340,000 bytes.
|
||||||
|
model = tf.keras.Sequential(
|
||||||
|
[
|
||||||
|
tf.keras.layers.Dense(200, use_bias=False), # size 80,000
|
||||||
|
tf.keras.layers.Dense(200, use_bias=False), # size 160,000
|
||||||
|
tf.keras.layers.Dense(100, use_bias=False), # size 80,000
|
||||||
|
tf.keras.layers.Dense(50, use_bias=False), # size 20,000
|
||||||
|
]
|
||||||
|
)
|
||||||
|
inputs = tf.zeros((1, 100), dtype=tf.float32)
|
||||||
|
model(inputs)
|
||||||
|
weights = model.weights
|
||||||
|
weights_dict = {w.name: w for w in weights}
|
||||||
|
with self.subTest("No shard when max size is bigger than model size"):
|
||||||
|
shards, index = tf_shard_checkpoint(weights)
|
||||||
|
self.assertIsNone(index)
|
||||||
|
self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
|
||||||
|
|
||||||
|
with self.subTest("Test sharding, no weights bigger than max size"):
|
||||||
|
shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
|
||||||
|
# Split is first two layers then last two.
|
||||||
|
self.assertDictEqual(
|
||||||
|
index,
|
||||||
|
{
|
||||||
|
"metadata": {"total_size": 340000},
|
||||||
|
"weight_map": {
|
||||||
|
"dense/kernel:0": "tf_model-00001-of-00002.h5",
|
||||||
|
"dense_1/kernel:0": "tf_model-00001-of-00002.h5",
|
||||||
|
"dense_2/kernel:0": "tf_model-00002-of-00002.h5",
|
||||||
|
"dense_3/kernel:0": "tf_model-00002-of-00002.h5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
|
||||||
|
shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||||
|
self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
|
||||||
|
|
||||||
|
with self.subTest("Test sharding with weights bigger than max size"):
|
||||||
|
shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
|
||||||
|
# Split is first layer, second layer then last 2.
|
||||||
|
self.assertDictEqual(
|
||||||
|
index,
|
||||||
|
{
|
||||||
|
"metadata": {"total_size": 340000},
|
||||||
|
"weight_map": {
|
||||||
|
"dense/kernel:0": "tf_model-00001-of-00003.h5",
|
||||||
|
"dense_1/kernel:0": "tf_model-00002-of-00003.h5",
|
||||||
|
"dense_2/kernel:0": "tf_model-00003-of-00003.h5",
|
||||||
|
"dense_3/kernel:0": "tf_model-00003-of-00003.h5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
shard1 = [weights_dict["dense/kernel:0"]]
|
||||||
|
shard2 = [weights_dict["dense_1/kernel:0"]]
|
||||||
|
shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||||
|
self.assertDictEqual(
|
||||||
|
shards,
|
||||||
|
{
|
||||||
|
"tf_model-00001-of-00003.h5": shard1,
|
||||||
|
"tf_model-00002-of-00003.h5": shard2,
|
||||||
|
"tf_model-00003-of-00003.h5": shard3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_local(self):
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||||
|
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||||
|
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||||
|
|
||||||
|
# Get each shard file and its size
|
||||||
|
shard_to_size = {}
|
||||||
|
for shard in os.listdir(tmp_dir):
|
||||||
|
if shard.endswith(".h5"):
|
||||||
|
shard_file = os.path.join(tmp_dir, shard)
|
||||||
|
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||||
|
|
||||||
|
index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
|
||||||
|
# Check there is an index but no regular weight file
|
||||||
|
self.assertTrue(os.path.isfile(index_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
# Check a file is bigger than max_size only when it has a single weight
|
||||||
|
for shard_file, size in shard_to_size.items():
|
||||||
|
if max_size.endswith("kiB"):
|
||||||
|
max_size_int = int(max_size[:-3]) * 2**10
|
||||||
|
else:
|
||||||
|
max_size_int = int(max_size[:-2]) * 10**3
|
||||||
|
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||||
|
# the size asked for (since we count parameters)
|
||||||
|
if size >= max_size_int + 50000:
|
||||||
|
with h5py.File(shard_file, "r") as state_file:
|
||||||
|
self.assertEqual(len(state_file), 1)
|
||||||
|
|
||||||
|
# Check the index and the shard files found match
|
||||||
|
with open(index_file, "r", encoding="utf-8") as f:
|
||||||
|
index = json.loads(f.read())
|
||||||
|
|
||||||
|
all_shards = set(index["weight_map"].values())
|
||||||
|
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".h5"))
|
||||||
|
self.assertSetEqual(all_shards, shards_found)
|
||||||
|
|
||||||
|
# Finally, check the model can be reloaded
|
||||||
|
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
model(model.dummy_inputs)
|
||||||
|
new_model(model.dummy_inputs)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user