From 6c24443ff510238196b8e3139b69c5e9bdaf4e2f Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 27 Oct 2022 15:56:29 -0400 Subject: [PATCH] Safetensors tf (#19900) * Wip * Add safetensors support for TensorFlow * First tests * Add final test for now * Retrigger CI like this * Update src/transformers/modeling_tf_utils.py Co-authored-by: Lysandre Debut Co-authored-by: Lysandre Debut --- src/transformers/modeling_tf_pytorch_utils.py | 4 +- src/transformers/modeling_tf_utils.py | 140 +++++++++++++++++- src/transformers/utils/__init__.py | 1 + src/transformers/utils/generic.py | 16 ++ tests/test_modeling_tf_common.py | 62 +++++++- 5 files changed, 208 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 864e34b016..3f2b564b70 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -21,7 +21,7 @@ import re import numpy -from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze +from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size from .utils import transpose as transpose_func @@ -273,7 +273,7 @@ def load_pytorch_state_dict_in_tf2_model( array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape) - tf_loaded_numel += array.size + tf_loaded_numel += tensor_size(array) weight_value_tuples.append((symbolic_weight, array)) all_pytorch_weights.discard(name) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 837e7a12bf..ac2b48e8c4 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -47,6 +47,8 @@ from .generation_tf_utils import TFGenerationMixin from .tf_utils import shape_list from .utils import ( DUMMY_INPUTS, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, 
TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, @@ -59,12 +61,18 @@ from .utils import ( has_file, is_offline_mode, is_remote_url, + is_safetensors_available, logging, requires_backends, working_or_temp_dir, ) +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.tensorflow import load_file as safe_load_file + from safetensors.tensorflow import save_file as safe_save_file + if TYPE_CHECKING: from . import PreTrainedTokenizerBase @@ -612,6 +620,14 @@ def dtype_byte_size(dtype): return bit_size // 8 +def format_weight_name(name, _prefix=None): + if "model." not in name and len(name.split("/")) > 1: + name = "/".join(name.split("/")[1:]) + if _prefix is not None: + name = _prefix + "/" + name + return name + + def tf_shard_checkpoint(weights, max_shard_size="10GB"): """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a @@ -849,6 +865,17 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the mismatched layers. 
""" + if resolved_archive_file.endswith(".safetensors"): + load_function = load_tf_weights_from_safetensors + else: + load_function = load_tf_weights_from_h5 + + return load_function( + model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix + ) + + +def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): missing_layers = [] unexpected_layers = [] mismatched_layers = [] @@ -952,6 +979,47 @@ def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, return missing_layers, unexpected_layers, mismatched_layers +def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): + # Read the safetensors file + state_dict = safe_load_file(resolved_archive_file) + + weight_value_tuples = [] + mismatched_layers = [] + + weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights] + loaded_weight_names = list(state_dict.keys()) + + # Find the missing layers from the high level list of layers + missing_layers = list(set(weight_names) - set(loaded_weight_names)) + # Find the unexpected layers from the high level list of layers + unexpected_layers = list(set(loaded_weight_names) - set(weight_names)) + + weight_value_tuples = [] + for weight in model.weights: + weight_name = format_weight_name(weight.name, _prefix=_prefix) + if weight_name in state_dict: + weight_value = state_dict[weight_name] + # Check if the shape of the current weight and the one from the H5 file are different + if K.int_shape(weight) != weight_value.shape: + # If yes we reshape the weight from the H5 file accordingly to the current weight + # If the two shapes are not compatible we raise an issue + try: + weight_value = tf.reshape(weight_value, K.int_shape(weight)) + except ValueError as e: + if ignore_mismatched_sizes: + mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight))) + continue + else: + raise e + 
+ weight_value_tuples.append((weight, weight_value)) + + # Load all the weights + K.batch_set_value(weight_value_tuples) + + return missing_layers, unexpected_layers, mismatched_layers + + def init_copy_embeddings(old_embeddings, new_num_tokens): r""" This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case @@ -2118,6 +2186,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu signatures=None, max_shard_size: Union[int, str] = "10GB", create_pr: bool = False, + safe_serialization: bool = False, **kwargs ): """ @@ -2152,6 +2221,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu create_pr (`bool`, *optional*, defaults to `False`): Whether or not to create a PR with the uploaded files or directly commit. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`). kwargs: Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. @@ -2186,7 +2257,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu self.config.save_pretrained(save_directory) # If we save using the predefined names, we can load using `from_pretrained` - output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME) + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME + output_model_file = os.path.join(save_directory, weights_name) shards, index = tf_shard_checkpoint(self.weights, max_shard_size) @@ -2195,15 +2267,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu full_filename = os.path.join(save_directory, filename) # If we have a shard file that is not going to be replaced, we delete it, but only from the main process # in distributed settings to avoid race conditions. 
+ weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") if ( - filename.startswith(TF2_WEIGHTS_NAME[:-4]) + filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards.keys() ): os.remove(full_filename) if index is None: - self.save_weights(output_model_file) + if safe_serialization: + state_dict = {format_weight_name(w.name): w.value() for w in self.weights} + safe_save_file(state_dict, output_model_file, metadata={"format": "tf"}) + else: + self.save_weights(output_model_file) logger.info(f"Model weights saved in {output_model_file}") else: save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME) @@ -2427,6 +2504,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu # Load from a sharded PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) is_sharded = True + elif is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + elif is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + is_sharded = True + raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): # Load from a TF 2.0 checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) @@ -2457,7 +2546,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu resolved_archive_file = download_url(pretrained_model_name_or_path) else: # set correct filename - filename = WEIGHTS_NAME if from_pt else 
TF2_WEIGHTS_NAME + if from_pt: + filename = WEIGHTS_NAME + elif is_safetensors_available(): + filename = SAFE_WEIGHTS_NAME + else: + filename = TF2_WEIGHTS_NAME try: # Load from URL or cache if already cached @@ -2476,8 +2570,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ) resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - # Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True + raise NotImplementedError( + "Support for sharded checkpoints using safetensors is coming soon!" + ) + else: + # This repo has no safetensors file of any kind, we switch to TensorFlow. + filename = TF2_WEIGHTS_NAME + resolved_archive_file = cached_file( + pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs + ) if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME: # Maybe the checkpoint is sharded, we try to grab the index name in this case. 
resolved_archive_file = cached_file( @@ -2521,6 +2631,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu if is_local: logger.info(f"loading weights file {archive_file}") resolved_archive_file = archive_file + filename = resolved_archive_file.split(os.path.sep)[-1] else: logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") else: @@ -2543,6 +2654,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu _commit_hash=commit_hash, ) + safetensors_from_pt = False + if filename == SAFE_WEIGHTS_NAME: + with safe_open(resolved_archive_file, framework="tf") as f: + safetensors_metadata = f.metadata() + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]: + raise OSError( + f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." + " Make sure you save your model with the `save_pretrained` method." + ) + safetensors_from_pt = safetensors_metadata.get("format") == "pt" + config.name_or_path = pretrained_model_name_or_path # composed models, *e.g.* TFRag, require special treatment when it comes to loading @@ -2560,6 +2682,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu return load_pytorch_checkpoint_in_tf2_model( model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info ) + elif safetensors_from_pt: + from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model + + state_dict = safe_load_file(resolved_archive_file) + # Load from a PyTorch checkpoint + return load_pytorch_state_dict_in_tf2_model( + model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info + ) # we might need to extend the variable scope for composite models if load_weight_prefix is not None: diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2dbca85df0..7ea8cc5585 100644 --- 
a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -49,6 +49,7 @@ from .generic import ( is_torch_tensor, reshape, squeeze, + tensor_size, to_numpy, to_py_obj, transpose, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 334141bd55..47619d2794 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -445,3 +445,19 @@ def expand_dims(array, axis): return jnp.expand_dims(array, axis=axis) else: raise ValueError(f"Type not supported for expand_dims: {type(array)}.") + + +def tensor_size(array): + """ + Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays. + """ + if is_numpy_array(array): + return np.size(array) + elif is_torch_tensor(array): + return array.numel() + elif is_tf_tensor(array): + return tf.size(array) + elif is_jax_tensor(array): + return array.size + else: + raise ValueError(f"Type not supported for tensor_size: {type(array)}.") diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index d137ae0faf..9e9e4d9930 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -43,13 +43,14 @@ from transformers.testing_utils import ( # noqa: F401 _tf_gpu_memory_limit, is_pt_tf_cross_test, is_staging_test, + require_safetensors, require_tf, require_tf2onnx, slow, tooslow, torch_device, ) -from transformers.utils import logging +from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging from transformers.utils.generic import ModelOutput @@ -94,12 +95,7 @@ if is_tf_available(): TFSampleDecoderOnlyOutput, TFSampleEncoderDecoderOutput, ) - from transformers.modeling_tf_utils import ( - TF2_WEIGHTS_INDEX_NAME, - TF2_WEIGHTS_NAME, - tf_shard_checkpoint, - unpack_inputs, - ) + from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs from transformers.tf_utils import stable_softmax if 
_tf_gpu_memory_limit is not None: @@ -119,6 +115,8 @@ if is_tf_available(): if is_torch_available(): import torch + from transformers import BertModel + def _config_zero_init(config): configs_no_init = copy.deepcopy(config) @@ -2168,7 +2166,7 @@ class UtilsFunctionsTest(unittest.TestCase): ) @slow - def test_special_layer_name_shardind(self): + def test_special_layer_name_sharding(self): retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True) model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever) @@ -2268,6 +2266,54 @@ class UtilsFunctionsTest(unittest.TestCase): self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys())) self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys())) + @require_safetensors + def test_safetensors_save_and_load(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + # No tf_model.h5 file, only a model.safetensors + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME))) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME))) + + new_model = TFBertModel.from_pretrained(tmp_dir) + + # Check models are equal + for p1, p2 in zip(model.weights, new_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + + @is_pt_tf_cross_test + def test_safetensors_save_and_load_pt_to_tf(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + with tempfile.TemporaryDirectory() as tmp_dir: + pt_model.save_pretrained(tmp_dir, safe_serialization=True) + # Check we have a model.safetensors file + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME))) + + new_model = TFBertModel.from_pretrained(tmp_dir) + + # Check models are equal + 
for p1, p2 in zip(model.weights, new_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + + @require_safetensors + def test_safetensors_load_from_hub(self): + tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + + # Can load from the TF-formatted checkpoint + safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf") + + # Check models are equal + for p1, p2 in zip(safetensors_model.weights, tf_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + + # Can load from the PyTorch-formatted checkpoint + safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors") + + # Check models are equal + for p1, p2 in zip(safetensors_model.weights, tf_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + @require_tf @is_staging_test