Safetensors tf (#19900)
* Wip * Add safetensors support for TensorFlow * First tests * Add final test for now * Retrigger CI like this * Update src/transformers/modeling_tf_utils.py Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -21,7 +21,7 @@ import re
|
||||
|
||||
import numpy
|
||||
|
||||
from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze
|
||||
from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
|
||||
from .utils import transpose as transpose_func
|
||||
|
||||
|
||||
@@ -273,7 +273,7 @@ def load_pytorch_state_dict_in_tf2_model(
|
||||
|
||||
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
|
||||
|
||||
tf_loaded_numel += array.size
|
||||
tf_loaded_numel += tensor_size(array)
|
||||
|
||||
weight_value_tuples.append((symbolic_weight, array))
|
||||
all_pytorch_weights.discard(name)
|
||||
|
||||
@@ -47,6 +47,8 @@ from .generation_tf_utils import TFGenerationMixin
|
||||
from .tf_utils import shape_list
|
||||
from .utils import (
|
||||
DUMMY_INPUTS,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
@@ -59,12 +61,18 @@ from .utils import (
|
||||
has_file,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_safetensors_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
working_or_temp_dir,
|
||||
)
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
from safetensors import safe_open
|
||||
from safetensors.tensorflow import load_file as safe_load_file
|
||||
from safetensors.tensorflow import save_file as safe_save_file
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import PreTrainedTokenizerBase
|
||||
|
||||
@@ -612,6 +620,14 @@ def dtype_byte_size(dtype):
|
||||
return bit_size // 8
|
||||
|
||||
|
||||
def format_weight_name(name, _prefix=None):
|
||||
if "model." not in name and len(name.split("/")) > 1:
|
||||
name = "/".join(name.split("/")[1:])
|
||||
if _prefix is not None:
|
||||
name = _prefix + "/" + name
|
||||
return name
|
||||
|
||||
|
||||
def tf_shard_checkpoint(weights, max_shard_size="10GB"):
|
||||
"""
|
||||
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||
@@ -849,6 +865,17 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
|
||||
Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
|
||||
mismatched layers.
|
||||
"""
|
||||
if resolved_archive_file.endswith(".safetensors"):
|
||||
load_function = load_tf_weights_from_safetensors
|
||||
else:
|
||||
load_function = load_tf_weights_from_h5
|
||||
|
||||
return load_function(
|
||||
model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix
|
||||
)
|
||||
|
||||
|
||||
def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
|
||||
missing_layers = []
|
||||
unexpected_layers = []
|
||||
mismatched_layers = []
|
||||
@@ -952,6 +979,47 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False,
|
||||
return missing_layers, unexpected_layers, mismatched_layers
|
||||
|
||||
|
||||
def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
|
||||
# Read the safetensors file
|
||||
state_dict = safe_load_file(resolved_archive_file)
|
||||
|
||||
weight_value_tuples = []
|
||||
mismatched_layers = []
|
||||
|
||||
weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
|
||||
loaded_weight_names = list(state_dict.keys())
|
||||
|
||||
# Find the missing layers from the high level list of layers
|
||||
missing_layers = list(set(weight_names) - set(loaded_weight_names))
|
||||
# Find the unexpected layers from the high level list of layers
|
||||
unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
|
||||
|
||||
weight_value_tuples = []
|
||||
for weight in model.weights:
|
||||
weight_name = format_weight_name(weight.name, _prefix=_prefix)
|
||||
if weight_name in state_dict:
|
||||
weight_value = state_dict[weight_name]
|
||||
# Check if the shape of the current weight and the one from the H5 file are different
|
||||
if K.int_shape(weight) != weight_value.shape:
|
||||
# If yes we reshape the weight from the H5 file accordingly to the current weight
|
||||
# If the two shapes are not compatible we raise an issue
|
||||
try:
|
||||
weight_value = tf.reshape(weight_value, K.int_shape(weight))
|
||||
except ValueError as e:
|
||||
if ignore_mismatched_sizes:
|
||||
mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
|
||||
weight_value_tuples.append((weight, weight_value))
|
||||
|
||||
# Load all the weights
|
||||
K.batch_set_value(weight_value_tuples)
|
||||
|
||||
return missing_layers, unexpected_layers, mismatched_layers
|
||||
|
||||
|
||||
def init_copy_embeddings(old_embeddings, new_num_tokens):
|
||||
r"""
|
||||
This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
|
||||
@@ -2118,6 +2186,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
signatures=None,
|
||||
max_shard_size: Union[int, str] = "10GB",
|
||||
create_pr: bool = False,
|
||||
safe_serialization: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
@@ -2152,6 +2221,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
|
||||
create_pr (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to create a PR with the uploaded files or directly commit.
|
||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
@@ -2186,7 +2257,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
self.config.save_pretrained(save_directory)
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
|
||||
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
|
||||
output_model_file = os.path.join(save_directory, weights_name)
|
||||
|
||||
shards, index = tf_shard_checkpoint(self.weights, max_shard_size)
|
||||
|
||||
@@ -2195,15 +2267,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
||||
# in distributed settings to avoid race conditions.
|
||||
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
|
||||
if (
|
||||
filename.startswith(TF2_WEIGHTS_NAME[:-4])
|
||||
filename.startswith(weights_no_suffix)
|
||||
and os.path.isfile(full_filename)
|
||||
and filename not in shards.keys()
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
if index is None:
|
||||
self.save_weights(output_model_file)
|
||||
if safe_serialization:
|
||||
state_dict = {format_weight_name(w.name): w.value() for w in self.weights}
|
||||
safe_save_file(state_dict, output_model_file, metadata={"format": "tf"})
|
||||
else:
|
||||
self.save_weights(output_model_file)
|
||||
logger.info(f"Model weights saved in {output_model_file}")
|
||||
else:
|
||||
save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
|
||||
@@ -2427,6 +2504,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# Load from a sharded PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||
):
|
||||
# Load from a safetensors checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||
):
|
||||
# Load from a sharded safetensors checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||
# Load from a TF 2.0 checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||
@@ -2457,7 +2546,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
resolved_archive_file = download_url(pretrained_model_name_or_path)
|
||||
else:
|
||||
# set correct filename
|
||||
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
|
||||
if from_pt:
|
||||
filename = WEIGHTS_NAME
|
||||
elif is_safetensors_available():
|
||||
filename = SAFE_WEIGHTS_NAME
|
||||
else:
|
||||
filename = TF2_WEIGHTS_NAME
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
@@ -2476,8 +2570,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
)
|
||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||
|
||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
|
||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||
# result when internet is up, the repo and revision exist, but the file does not.
|
||||
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
raise NotImplementedError(
|
||||
"Support for sharded checkpoints using safetensors is coming soon!"
|
||||
)
|
||||
else:
|
||||
# This repo has no safetensors file of any kind, we switch to TensorFlow.
|
||||
filename = TF2_WEIGHTS_NAME
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
|
||||
)
|
||||
if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
resolved_archive_file = cached_file(
|
||||
@@ -2521,6 +2631,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
if is_local:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
resolved_archive_file = archive_file
|
||||
filename = resolved_archive_file.split(os.path.sep)[-1]
|
||||
else:
|
||||
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
|
||||
else:
|
||||
@@ -2543,6 +2654,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
_commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
safetensors_from_pt = False
|
||||
if filename == SAFE_WEIGHTS_NAME:
|
||||
with safe_open(resolved_archive_file, framework="tf") as f:
|
||||
safetensors_metadata = f.metadata()
|
||||
if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
|
||||
raise OSError(
|
||||
f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
|
||||
" Make sure you save your model with the `save_pretrained` method."
|
||||
)
|
||||
safetensors_from_pt = safetensors_metadata.get("format") == "pt"
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
# composed models, *e.g.* TFRag, require special treatment when it comes to loading
|
||||
@@ -2560,6 +2682,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
return load_pytorch_checkpoint_in_tf2_model(
|
||||
model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
|
||||
)
|
||||
elif safetensors_from_pt:
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
|
||||
|
||||
state_dict = safe_load_file(resolved_archive_file)
|
||||
# Load from a PyTorch checkpoint
|
||||
return load_pytorch_state_dict_in_tf2_model(
|
||||
model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
|
||||
)
|
||||
|
||||
# we might need to extend the variable scope for composite models
|
||||
if load_weight_prefix is not None:
|
||||
|
||||
@@ -49,6 +49,7 @@ from .generic import (
|
||||
is_torch_tensor,
|
||||
reshape,
|
||||
squeeze,
|
||||
tensor_size,
|
||||
to_numpy,
|
||||
to_py_obj,
|
||||
transpose,
|
||||
|
||||
@@ -445,3 +445,19 @@ def expand_dims(array, axis):
|
||||
return jnp.expand_dims(array, axis=axis)
|
||||
else:
|
||||
raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
|
||||
|
||||
|
||||
def tensor_size(array):
|
||||
"""
|
||||
Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.
|
||||
"""
|
||||
if is_numpy_array(array):
|
||||
return np.size(array)
|
||||
elif is_torch_tensor(array):
|
||||
return array.numel()
|
||||
elif is_tf_tensor(array):
|
||||
return tf.size(array)
|
||||
elif is_jax_tensor(array):
|
||||
return array.size
|
||||
else:
|
||||
raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
|
||||
|
||||
@@ -43,13 +43,14 @@ from transformers.testing_utils import ( # noqa: F401
|
||||
_tf_gpu_memory_limit,
|
||||
is_pt_tf_cross_test,
|
||||
is_staging_test,
|
||||
require_safetensors,
|
||||
require_tf,
|
||||
require_tf2onnx,
|
||||
slow,
|
||||
tooslow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
||||
from transformers.utils.generic import ModelOutput
|
||||
|
||||
|
||||
@@ -94,12 +95,7 @@ if is_tf_available():
|
||||
TFSampleDecoderOnlyOutput,
|
||||
TFSampleEncoderDecoderOutput,
|
||||
)
|
||||
from transformers.modeling_tf_utils import (
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
tf_shard_checkpoint,
|
||||
unpack_inputs,
|
||||
)
|
||||
from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs
|
||||
from transformers.tf_utils import stable_softmax
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
@@ -119,6 +115,8 @@ if is_tf_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
@@ -2168,7 +2166,7 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_special_layer_name_shardind(self):
|
||||
def test_special_layer_name_sharding(self):
|
||||
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
|
||||
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
@@ -2268,6 +2266,54 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
|
||||
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_save_and_load(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
# No tf_model.h5 file, only a model.safetensors
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_safetensors_save_and_load_pt_to_tf(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pt_model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
# Check we have a model.safetensors file
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub(self):
|
||||
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Can load from the TF-formatted checkpoint
|
||||
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user