From 113ebf80ac9bdb74037239847cd906d7ea986a18 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 31 Oct 2023 19:16:49 +0100 Subject: [PATCH] Safetensors serialization by default (#27064) * Safetensors serialization by default * First pass on the tests * Second pass on the tests * Third pass on the tests * Fix TF weight loading from TF-format safetensors * Specific encoder-decoder fixes for weight crossloading * Add VisionEncoderDecoder fixes for TF too * Change filename test for pt-to-tf * One missing fix for TFVisionEncoderDecoder * Fix the other crossload test * Support for flax + updated tests * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Sanchit's comments * Sanchit's comments 2 * Nico's comments * Fix tests * cleanup * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Matt Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../modeling_flax_pytorch_utils.py | 28 +++- src/transformers/modeling_flax_utils.py | 126 ++++++++++++++---- src/transformers/modeling_tf_utils.py | 57 ++++---- src/transformers/modeling_utils.py | 34 +++-- .../modeling_encoder_decoder.py | 4 +- .../modeling_tf_encoder_decoder.py | 22 +-- .../modeling_tf_vision_encoder_decoder.py | 22 +-- .../modeling_vision_encoder_decoder.py | 4 +- src/transformers/pipelines/base.py | 4 +- src/transformers/training_args.py | 4 +- src/transformers/utils/hub.py | 4 +- tests/models/auto/test_modeling_tf_auto.py | 5 + .../test_modeling_tf_encoder_decoder.py | 13 +- ...test_modeling_tf_vision_encoder_decoder.py | 10 +- tests/test_modeling_common.py | 5 +- tests/test_modeling_flax_utils.py | 92 ++++++++++++- tests/test_modeling_tf_utils.py | 41 +++++- tests/test_modeling_utils.py | 88 +++++++++--- tests/trainer/test_trainer.py | 6 +- tests/utils/test_cli.py | 1 - 20 files changed, 433 insertions(+), 137 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 79d91da497..5a0f52a995 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -27,9 +27,15 @@ from flax.traverse_util import flatten_dict, unflatten_dict import transformers +from . import is_safetensors_available from .utils import logging +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.flax import load_file as safe_load_file + + logger = logging.get_logger(__name__) @@ -56,7 +62,13 @@ def load_pytorch_checkpoint_in_flax_state_dict( pt_path = os.path.abspath(pytorch_checkpoint_path) logger.info(f"Loading PyTorch weights from {pt_path}") - pt_state_dict = torch.load(pt_path, map_location="cpu") + if pt_path.endswith(".safetensors"): + pt_state_dict = {} + with safe_open(pt_path, framework="pt") as f: + for k in f.keys(): + pt_state_dict[k] = f.get_tensor(k) + else: + pt_state_dict = torch.load(pt_path, map_location="cpu") logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) @@ -319,11 +331,15 @@ def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path): flax_cls = getattr(transformers, "Flax" + model.__class__.__name__) # load flax weight dict - with open(flax_checkpoint_path, "rb") as state_f: - try: - flax_state_dict = from_bytes(flax_cls, state_f.read()) - except UnpicklingError: - raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ") + if flax_checkpoint_path.endswith(".safetensors"): + flax_state_dict = safe_load_file(flax_checkpoint_path) + flax_state_dict = unflatten_dict(flax_state_dict, sep=".") + else: + with open(flax_checkpoint_path, "rb") as state_f: + try: + flax_state_dict = from_bytes(flax_cls, state_f.read()) + except UnpicklingError: + raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ") return load_flax_weights_in_pytorch_model(model, flax_state_dict) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index b05fa4d72a..9e63cb0cb9 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -39,6 +39,8 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d from .utils import ( FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, PushToHubMixin, @@ -54,8 +56,14 @@ from .utils import ( replace_return_docstrings, ) from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files +from .utils.import_utils import is_safetensors_available +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.flax import load_file as safe_load_file + from safetensors.flax import save_file as safe_save_file + logger = logging.get_logger(__name__) @@ -422,6 +430,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ```""" return self._cast_floating_to(params, jnp.float16, mask) + @classmethod + def load_flax_weights(cls, resolved_archive_file): + try: + if resolved_archive_file.endswith(".safetensors"): + state = safe_load_file(resolved_archive_file) + state = unflatten_dict(state, sep=".") + else: + with open(resolved_archive_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(resolved_archive_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ") + + return state + @classmethod def load_flax_sharded_weights(cls, shard_files): """ @@ -688,7 +721,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) if os.path.isdir(pretrained_model_name_or_path): - if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): + if is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) elif from_pt and os.path.isfile( @@ -705,6 +743,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME) is_sharded = True # At this stage we don't have a weight file so we will raise an error. + elif is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + is_sharded = True + raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): raise EnvironmentError( f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " @@ -723,7 +768,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): filename = pretrained_model_name_or_path resolved_archive_file = download_url(pretrained_model_name_or_path) else: - filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME + if from_pt: + filename = WEIGHTS_NAME + elif is_safetensors_available(): + filename = SAFE_WEIGHTS_NAME + else: + filename = FLAX_WEIGHTS_NAME + try: # Load from URL or cache if already cached cached_file_kwargs = { @@ -741,8 +792,15 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): } resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - # Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME: + # Did not find the safetensors file, let's fallback to Flax. + # No support for sharded safetensors yet, so we'll raise an error if that's all we find. + filename = FLAX_WEIGHTS_NAME + resolved_archive_file = cached_file( + pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **cached_file_kwargs + ) if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME: # Maybe the checkpoint is sharded, we try to grab the index name in this case. resolved_archive_file = cached_file( @@ -751,21 +809,26 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): if resolved_archive_file is not None: is_sharded = True # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. - elif resolved_archive_file is None and from_pt: + if resolved_archive_file is None and from_pt: resolved_archive_file = cached_file( pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs ) if resolved_archive_file is not None: is_sharded = True if resolved_archive_file is None: - # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error # message. has_file_kwargs = { "revision": revision, "proxies": proxies, "token": token, } - if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): + is_sharded = True + raise NotImplementedError( + "Support for sharded checkpoints using safetensors is coming soon!" + ) + elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" @@ -798,6 +861,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): if is_local: logger.info(f"loading weights file {archive_file}") resolved_archive_file = archive_file + filename = resolved_archive_file.split(os.path.sep)[-1] else: logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") else: @@ -821,31 +885,27 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): _commit_hash=commit_hash, ) + safetensors_from_pt = False + if filename == SAFE_WEIGHTS_NAME: + with safe_open(resolved_archive_file, framework="flax") as f: + safetensors_metadata = f.metadata() + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]: + raise OSError( + f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." + " Make sure you save your model with the `save_pretrained` method." + ) + safetensors_from_pt = safetensors_metadata.get("format") == "pt" + # init random models model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) - if from_pt: + if from_pt or safetensors_from_pt: state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) else: if is_sharded: state = cls.load_flax_sharded_weights(resolved_archive_file) else: - try: - with open(resolved_archive_file, "rb") as state_f: - state = from_bytes(cls, state_f.read()) - except (UnpicklingError, msgpack.exceptions.ExtraData) as e: - try: - with open(resolved_archive_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please" - " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" - " folder you cloned." - ) - else: - raise ValueError from e - except (UnicodeDecodeError, ValueError): - raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ") + state = cls.load_flax_weights(resolved_archive_file) # make sure all arrays are stored as jnp.arrays # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: # https://github.com/google/flax/issues/1261 @@ -1030,6 +1090,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): push_to_hub=False, max_shard_size="10GB", token: Optional[Union[str, bool]] = None, + safe_serialization: bool = False, **kwargs, ): """ @@ -1059,6 +1120,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). kwargs (`Dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or through msgpack. """ use_auth_token = kwargs.pop("use_auth_token", None) @@ -1103,24 +1166,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): self.generation_config.save_pretrained(save_directory) # save model - output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME + output_model_file = os.path.join(save_directory, weights_name) shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size) # Clean the folder from a previous save for filename in os.listdir(save_directory): full_filename = os.path.join(save_directory, filename) + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") if ( - filename.startswith(FLAX_WEIGHTS_NAME[:-4]) + filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards.keys() ): os.remove(full_filename) if index is None: - with open(output_model_file, "wb") as f: + if safe_serialization: params = params if params is not None else self.params - model_bytes = to_bytes(params) - f.write(model_bytes) + flat_dict = flatten_dict(params, sep=".") + safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"}) + else: + with open(output_model_file, "wb") as f: + params = params if params is not None else self.params + model_bytes = to_bytes(params) + f.write(model_bytes) else: save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index bea3edfa22..c342b5059c 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -626,11 +626,13 @@ def dtype_byte_size(dtype): return bit_size // 8 -def format_weight_name(name, _prefix=None): +def strip_model_name_and_prefix(name, _prefix=None): + if _prefix is not None and name.startswith(_prefix): + name = name[len(_prefix) :] + if name.startswith("/"): + name = name[1:] if "model." not in name and len(name.split("/")) > 1: name = "/".join(name.split("/")[1:]) - if _prefix is not None: - name = _prefix + "/" + name return name @@ -986,7 +988,7 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat # Read the safetensors file with safe_open(resolved_archive_file, framework="tf") as safetensors_archive: mismatched_layers = [] - weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights] + weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights] loaded_weight_names = list(safetensors_archive.keys()) # Find the missing layers from the high level list of layers missing_layers = list(set(weight_names) - set(loaded_weight_names)) @@ -994,7 +996,7 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat unexpected_layers = list(set(loaded_weight_names) - set(weight_names)) for weight in model.weights: - weight_name = format_weight_name(weight.name, _prefix=_prefix) + weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix) if weight_name in loaded_weight_names: weight_value = safetensors_archive.get_tensor(weight_name) # Check if the shape of the current weight and the one from the H5 file are different @@ -1003,7 +1005,7 @@ def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismat # If the two shapes are not compatible we raise an issue try: weight_value = tf.reshape(weight_value, K.int_shape(weight)) - except ValueError as e: + except (ValueError, tf.errors.InvalidArgumentError) as e: if ignore_mismatched_sizes: mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight))) continue @@ -2367,7 +2369,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu create_pr (`bool`, *optional*, defaults to `False`): Whether or not to create a PR with the uploaded files or directly commit. safe_serialization (`bool`, *optional*, defaults to `False`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`). token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). @@ -2457,7 +2459,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu if index is None: if safe_serialization: - state_dict = {format_weight_name(w.name): w.value() for w in self.weights} + state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights} safe_save_file(state_dict, output_model_file, metadata={"format": "tf"}) else: self.save_weights(output_model_file) @@ -2718,13 +2720,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ): # Load from a safetensors checkpoint archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) - elif is_safetensors_available() and os.path.isfile( - os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - ): - # Load from a sharded safetensors checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - is_sharded = True - raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): # Load from a TF 2.0 checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) @@ -2732,6 +2727,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu # Load from a sharded TF 2.0 checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) is_sharded = True + elif is_safetensors_available() and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + is_sharded = True + raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") # At this stage we don't have a weight file so we will raise an error. elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile( os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) @@ -2784,21 +2786,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # result when internet is up, the repo and revision exist, but the file does not. if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME: - # Maybe the checkpoint is sharded, we try to grab the index name in this case. + # Did not find the safetensors file, let's fallback to TF. + # No support for sharded safetensors yet, so we'll raise an error if that's all we find. + filename = TF2_WEIGHTS_NAME resolved_archive_file = cached_file( - pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs + pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs ) - if resolved_archive_file is not None: - is_sharded = True - raise NotImplementedError( - "Support for sharded checkpoints using safetensors is coming soon!" - ) - else: - # This repo has no safetensors file of any kind, we switch to TensorFlow. - filename = TF2_WEIGHTS_NAME - resolved_archive_file = cached_file( - pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs - ) if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME: # Maybe the checkpoint is sharded, we try to grab the index name in this case. resolved_archive_file = cached_file( @@ -2821,7 +2814,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu "proxies": proxies, "token": token, } - if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): + if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): + is_sharded = True + raise NotImplementedError( + "Support for sharded checkpoints using safetensors is coming soon!" + ) + elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" @@ -2928,6 +2926,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu output_loading_info=output_loading_info, _prefix=load_weight_prefix, ignore_mismatched_sizes=ignore_mismatched_sizes, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, ) # 'by_name' allow us to do transfer learning by skipping/adding layers diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7d02d53fc3..ccb9073aef 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -470,10 +470,6 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " "you save your model with the `save_pretrained` method." ) - elif metadata["format"] != "pt": - raise NotImplementedError( - f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." - ) return safe_load_file(checkpoint_file) try: if ( @@ -1934,7 +1930,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix save_function: Callable = torch.save, push_to_hub: bool = False, max_shard_size: Union[int, str] = "5GB", - safe_serialization: bool = False, + safe_serialization: bool = True, variant: Optional[str] = None, token: Optional[Union[str, bool]] = None, save_peft_format: bool = True, @@ -1975,7 +1971,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix - safe_serialization (`bool`, *optional*, defaults to `False`): + safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). variant (`str`, *optional*): If specified, weights are saved in the format pytorch_model..bin. @@ -2736,8 +2732,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix " sure the weights are in PyTorch format." ) - from_pt = not (from_tf | from_flax) - user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} if from_pipeline is not None: user_agent["using_pipeline"] = from_pipeline @@ -3103,6 +3097,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _commit_hash=commit_hash, ) + if ( + is_safetensors_available() + and isinstance(resolved_archive_file, str) + and resolved_archive_file.endswith(".safetensors") + ): + with safe_open(resolved_archive_file, framework="pt") as f: + metadata = f.metadata() + + if metadata.get("format") == "pt": + pass + elif metadata.get("format") == "tf": + from_tf = True + logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.") + elif metadata.get("format") == "flax": + from_flax = True + logger.info("A Flax safetensors file is being loaded in a PyTorch model.") + else: + raise ValueError( + f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}" + ) + + from_pt = not (from_tf | from_flax) + # load pt weights early so that we know which dtype to init the model under if from_pt: if not is_sharded and state_dict is None: @@ -3391,7 +3408,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # restore default dtype if dtype_orig is not None: torch.set_default_dtype(dtype_orig) - ( model, missing_keys, diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index ff5a56749f..27a213707c 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -366,8 +366,8 @@ class EncoderDecoderModel(PreTrainedModel): model.config = config if hasattr(model, "enc_to_dec_proj"): - model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight - model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias + model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous() + model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous() return model diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index 19fc47546b..14653410b0 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -306,17 +306,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name! - if kwargs.get("from_pt", False): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - encoder_model_type = config.encoder.model_type + # This override is only needed in the case where we're crossloading weights from PT. However, since weights are + # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file. + # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it + # or not. - def tf_to_pt_weight_rename(tf_weight): - if "encoder" in tf_weight and "decoder" not in tf_weight: - return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight) - else: - return tf_weight + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + encoder_model_type = config.encoder.model_type - kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename + def tf_to_pt_weight_rename(tf_weight): + if "encoder" in tf_weight and "decoder" not in tf_weight: + return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight) + else: + return tf_weight + + kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @classmethod diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index a0fae071a1..dea1aaaf59 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -322,17 +322,21 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name! - if kwargs.get("from_pt", False): - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - encoder_model_type = config.encoder.model_type + # This override is only needed in the case where we're crossloading weights from PT. However, since weights are + # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file. + # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it + # or not. - def tf_to_pt_weight_rename(tf_weight): - if "encoder" in tf_weight and "decoder" not in tf_weight: - return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight) - else: - return tf_weight + config = AutoConfig.from_pretrained(pretrained_model_name_or_path) + encoder_model_type = config.encoder.model_type - kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename + def tf_to_pt_weight_rename(tf_weight): + if "encoder" in tf_weight and "decoder" not in tf_weight: + return re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight) + else: + return tf_weight + + kwargs["tf_to_pt_weight_rename"] = tf_to_pt_weight_rename return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @classmethod diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 60646809a6..f9c6c25cd8 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -342,8 +342,8 @@ class VisionEncoderDecoderModel(PreTrainedModel): model.config = config if hasattr(model, "enc_to_dec_proj"): - model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight - model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias + model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous() + model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous() return model diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 36c9585a69..2d18384d1b 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -836,7 +836,7 @@ class Pipeline(_ScikitCompat): # then we should keep working self.image_processor = self.feature_extractor - def save_pretrained(self, save_directory: str, safe_serialization: bool = False): + def save_pretrained(self, save_directory: str, safe_serialization: bool = True): """ Save the pipeline's model and tokenizer. @@ -844,7 +844,7 @@ class Pipeline(_ScikitCompat): save_directory (`str`): A path to the directory where to saved. It will be created if it doesn't exist. safe_serialization (`str`): - Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow + Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow. """ if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 8a6d7255f5..147d1e6b1c 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -293,7 +293,7 @@ class TrainingArguments: `save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two checkpoints are saved: the last one and the best one (if they are different). - save_safetensors (`bool`, *optional*, defaults to `False`): + save_safetensors (`bool`, *optional*, defaults to `True`): Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of default `torch.load` and `torch.save`. save_on_each_node (`bool`, *optional*, defaults to `False`): @@ -797,7 +797,7 @@ class TrainingArguments: }, ) save_safetensors: Optional[bool] = field( - default=False, + default=True, metadata={ "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save." }, diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 2dcfd7f3c8..0d58211da8 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -797,7 +797,7 @@ class PushToHubMixin: token: Optional[Union[bool, str]] = None, max_shard_size: Optional[Union[int, str]] = "5GB", create_pr: bool = False, - safe_serialization: bool = False, + safe_serialization: bool = True, revision: str = None, commit_description: str = None, **deprecated_kwargs, @@ -827,7 +827,7 @@ class PushToHubMixin: Google Colab instances without any CPU OOM issues. create_pr (`bool`, *optional*, defaults to `False`): Whether or not to create a PR with the uploaded files or directly commit. - safe_serialization (`bool`, *optional*, defaults to `False`): + safe_serialization (`bool`, *optional*, defaults to `True`): Whether or not to convert the model weights in safetensors format for safer serialization. revision (`str`, *optional*): Branch to push the uploaded files to. diff --git a/tests/models/auto/test_modeling_tf_auto.py b/tests/models/auto/test_modeling_tf_auto.py index c8754ca427..2f6fe47615 100644 --- a/tests/models/auto/test_modeling_tf_auto.py +++ b/tests/models/auto/test_modeling_tf_auto.py @@ -211,6 +211,8 @@ class TFAutoModelTest(unittest.TestCase): config = copy.deepcopy(model.config) config.architectures = ["FunnelBaseModel"] model = TFAutoModel.from_config(config) + model.build() + self.assertIsInstance(model, TFFunnelBaseModel) with tempfile.TemporaryDirectory() as tmp_dir: @@ -245,7 +247,10 @@ class TFAutoModelTest(unittest.TestCase): # Now that the config is registered, it can be used as any other config with the auto-API tiny_config = BertModelTester(self).get_config() config = NewModelConfig(**tiny_config.to_dict()) + model = auto_class.from_config(config) + model.build() + self.assertIsInstance(model, TFNewModel) with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py index ab5da3d41e..1d8d4e985b 100644 --- a/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py @@ -525,7 +525,7 @@ class TFEncoderDecoderMixin: # PT -> TF with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) + tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname) self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) @@ -542,7 +542,7 @@ class TFEncoderDecoderMixin: with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) + tf_model = TFEncoderDecoderModel.from_pretrained(tmpdirname) self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict) @@ -560,7 +560,8 @@ class TFEncoderDecoderMixin: tf_model(**tf_inputs_dict) with tempfile.TemporaryDirectory() as tmpdirname: - tf_model.save_pretrained(tmpdirname) + # TODO Matt: PT doesn't support loading TF safetensors - remove the arg and from_tf=True when it does + tf_model.save_pretrained(tmpdirname, safe_serialization=False) pt_model = EncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True) self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict) @@ -1129,9 +1130,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase): with tempfile.TemporaryDirectory() as tmp_dirname_1, tempfile.TemporaryDirectory() as tmp_dirname_2: encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1) encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2) - encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained( - tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True - ) + encoder_decoder_tf = TFEncoderDecoderModel.from_encoder_decoder_pretrained(tmp_dirname_1, tmp_dirname_2) logits_tf = encoder_decoder_tf(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits @@ -1150,7 +1149,7 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase): # TensorFlow => PyTorch with tempfile.TemporaryDirectory() as tmp_dirname: - encoder_decoder_tf.save_pretrained(tmp_dirname) + encoder_decoder_tf.save_pretrained(tmp_dirname, safe_serialization=False) encoder_decoder_pt = EncoderDecoderModel.from_pretrained(tmp_dirname, from_tf=True) max_diff = np.max(np.abs(logits_pt.detach().cpu().numpy() - logits_tf.numpy())) diff --git a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index e173e21a9b..2cb5e44672 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -458,7 +458,7 @@ class TFVisionEncoderDecoderMixin: # PT -> TF with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) + tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname) self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) @@ -473,7 +473,7 @@ class TFVisionEncoderDecoderMixin: with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True) + tf_model = TFVisionEncoderDecoderModel.from_pretrained(tmpdirname) self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict) @@ -489,7 +489,7 @@ class TFVisionEncoderDecoderMixin: tf_model(**tf_inputs_dict) with tempfile.TemporaryDirectory() as tmpdirname: - tf_model.save_pretrained(tmpdirname) + tf_model.save_pretrained(tmpdirname, safe_serialization=False) pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True) self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict) @@ -803,7 +803,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase): encoder_decoder_pt.encoder.save_pretrained(tmp_dirname_1) encoder_decoder_pt.decoder.save_pretrained(tmp_dirname_2) encoder_decoder_tf = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained( - tmp_dirname_1, tmp_dirname_2, encoder_from_pt=True, decoder_from_pt=True + tmp_dirname_1, tmp_dirname_2 ) logits_tf = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits @@ -814,7 +814,7 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase): # Make sure `from_pretrained` following `save_pretrained` work and give the same result # (See https://github.com/huggingface/transformers/pull/14016) with tempfile.TemporaryDirectory() as tmp_dirname: - encoder_decoder_tf.save_pretrained(tmp_dirname) + encoder_decoder_tf.save_pretrained(tmp_dirname, safe_serialization=False) encoder_decoder_tf = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname) logits_tf_2 = encoder_decoder_tf(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 634d7631df..3c48100747 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -91,6 +91,7 @@ if is_accelerate_available(): if is_torch_available(): import torch + from safetensors.torch import save_file as safe_save_file from torch import nn from transformers import MODEL_MAPPING, AdaptiveEmbedding @@ -1751,8 +1752,8 @@ class ModelTesterMixin: # We are nuking ALL weights on file, so every parameter should # yell on load. We're going to detect if we yell too much, or too little. - with open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb") as f: - torch.save({}, f) + placeholder_dict = {"tensor": torch.tensor([1, 2])} + safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"}) model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True) prefix = f"{model_reloaded.base_model_prefix}." diff --git a/tests/test_modeling_flax_utils.py b/tests/test_modeling_flax_utils.py index d8fb71a610..06ed30f8af 100644 --- a/tests/test_modeling_flax_utils.py +++ b/tests/test_modeling_flax_utils.py @@ -16,11 +16,12 @@ import tempfile import unittest import numpy as np -from huggingface_hub import HfFolder, delete_repo +from huggingface_hub import HfFolder, delete_repo, snapshot_download from requests.exceptions import HTTPError -from transformers import BertConfig, is_flax_available -from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax +from transformers import BertConfig, BertModel, is_flax_available +from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch +from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME if is_flax_available(): @@ -184,3 +185,88 @@ class FlaxModelUtilsTest(unittest.TestCase): model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder) self.assertIsNotNone(model) + + @require_safetensors + def test_safetensors_save_and_load(self): + model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + + # No msgpack file, only a model.safetensors + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME))) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME))) + + new_model = FlaxBertModel.from_pretrained(tmp_dir) + + self.assertTrue(check_models_equal(model, new_model)) + + @require_flax + @require_torch + def test_safetensors_save_and_load_pt_to_flax(self): + model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True) + pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + with tempfile.TemporaryDirectory() as tmp_dir: + pt_model.save_pretrained(tmp_dir) + + # Check we have a model.safetensors file + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME))) + + new_model = FlaxBertModel.from_pretrained(tmp_dir) + + # Check models are equal + self.assertTrue(check_models_equal(model, new_model)) + + @require_safetensors + def test_safetensors_load_from_hub(self): + flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + + # Can load from the Flax-formatted checkpoint + safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only") + self.assertTrue(check_models_equal(flax_model, safetensors_model)) + + @require_torch + @require_safetensors + def test_safetensors_load_from_hub_flax_and_pt(self): + flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + + # Can load from the PyTorch-formatted checkpoint + safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only", from_pt=True) + self.assertTrue(check_models_equal(flax_model, safetensors_model)) + + @require_safetensors + def test_safetensors_flax_from_flax(self): + model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + new_model = FlaxBertModel.from_pretrained(tmp_dir) + + self.assertTrue(check_models_equal(model, new_model)) + + @require_safetensors + @require_torch + def test_safetensors_flax_from_torch(self): + hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + new_model = FlaxBertModel.from_pretrained(tmp_dir) + + self.assertTrue(check_models_equal(hub_model, new_model)) + + @require_safetensors + def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self): + with tempfile.TemporaryDirectory() as tmp_dir: + path = snapshot_download( + "hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded", cache_dir=tmp_dir + ) + + # This should not raise even if there are two types of sharded weights + FlaxBertModel.from_pretrained(path) + + @require_safetensors + def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self): + # This should not raise even if there are two types of sharded weights + # This should discard the safetensors weights in favor of the msgpack sharded weights + FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded") diff --git a/tests/test_modeling_tf_utils.py b/tests/test_modeling_tf_utils.py index 862a2cffa8..6d0ed86407 100644 --- a/tests/test_modeling_tf_utils.py +++ b/tests/test_modeling_tf_utils.py @@ -24,7 +24,7 @@ import tempfile import unittest import unittest.mock as mock -from huggingface_hub import HfFolder, Repository, delete_repo +from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download from huggingface_hub.file_download import http_get from requests.exceptions import HTTPError @@ -39,6 +39,7 @@ from transformers.testing_utils import ( # noqa: F401 is_staging_test, require_safetensors, require_tf, + require_torch, slow, ) from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging @@ -496,6 +497,44 @@ class TFModelUtilsTest(unittest.TestCase): for p1, p2 in zip(safetensors_model.weights, tf_model.weights): self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + @require_safetensors + def test_safetensors_tf_from_tf(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only") + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + new_model = TFBertModel.from_pretrained(tmp_dir) + + for p1, p2 in zip(model.weights, new_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + + @require_safetensors + @is_pt_tf_cross_test + def test_safetensors_tf_from_torch(self): + hub_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only") + model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + new_model = TFBertModel.from_pretrained(tmp_dir) + + for p1, p2 in zip(hub_model.weights, new_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + + @require_safetensors + def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_local(self): + with tempfile.TemporaryDirectory() as tmp_dir: + path = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", cache_dir=tmp_dir) + + # This should not raise even if there are two types of sharded weights + TFBertModel.from_pretrained(path) + + @require_safetensors + def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self): + # This should not raise even if there are two types of sharded weights + # This should discard the safetensors weights in favor of the .h5 sharded weights + TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded") + @require_tf @is_staging_test diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 9e824e8efa..8456871df6 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy import glob import json import os @@ -42,7 +42,9 @@ from transformers.testing_utils import ( TestCasePlus, is_staging_test, require_accelerate, + require_flax, require_safetensors, + require_tf, require_torch, require_torch_accelerator, require_torch_multi_accelerator, @@ -56,7 +58,7 @@ from transformers.utils import ( WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ) -from transformers.utils.import_utils import is_torchdynamo_available +from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torchdynamo_available sys.path.append(str(Path(__file__).parent.parent / "utils")) @@ -66,6 +68,7 @@ from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # if is_torch_available(): import torch + from safetensors.torch import save_file as safe_save_file from test_module.custom_modeling import CustomModel, NoSuperInitModel from torch import nn @@ -146,6 +149,13 @@ if is_torch_available(): self.decoder.weight = self.base.linear.weight +if is_flax_available(): + from transformers import FlaxBertModel + +if is_tf_available(): + from transformers import TFBertModel + + TINY_T5 = "patrickvonplaten/t5-tiny-random" TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification" @@ -420,13 +430,13 @@ class ModelUtilsTest(TestCasePlus): }, ) - def test_checkpoint_sharding_local(self): + def test_checkpoint_sharding_local_bin(self): model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") with tempfile.TemporaryDirectory() as tmp_dir: # We use the same folder for various sizes to make sure a new save erases the old checkpoint. for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]: - model.save_pretrained(tmp_dir, max_shard_size=max_size) + model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False) # Get each shard file and its size shard_to_size = {} @@ -472,11 +482,11 @@ class ModelUtilsTest(TestCasePlus): for p1, p2 in zip(model.parameters(), ref_model.parameters()): self.assertTrue(torch.allclose(p1, p2)) - def test_checkpoint_variant_local(self): + def test_checkpoint_variant_local_bin(self): model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir, variant="v2") + model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False) weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"]) @@ -492,11 +502,11 @@ class ModelUtilsTest(TestCasePlus): for p1, p2 in zip(model.parameters(), new_model.parameters()): self.assertTrue(torch.allclose(p1, p2)) - def test_checkpoint_variant_local_sharded(self): + def test_checkpoint_variant_local_sharded_bin(self): model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB") + model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=False) weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"]) weights_index_file = os.path.join(tmp_dir, weights_index_name) @@ -604,18 +614,18 @@ class ModelUtilsTest(TestCasePlus): ) self.assertIsNotNone(model) - def test_checkpoint_variant_save_load(self): + def test_checkpoint_variant_save_load_bin(self): with tempfile.TemporaryDirectory() as tmp_dir: model = BertModel.from_pretrained( "hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2" ) weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"]) - model.save_pretrained(tmp_dir, variant="v2") + model.save_pretrained(tmp_dir, variant="v2", safe_serialization=False) # saving will create a variant checkpoint self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name))) - model.save_pretrained(tmp_dir) + model.save_pretrained(tmp_dir, safe_serialization=False) # saving shouldn't delete variant checkpoints weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"]) self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name))) @@ -874,7 +884,7 @@ class ModelUtilsTest(TestCasePlus): def test_base_model_to_head_model_load(self): base_model = BaseModel(PretrainedConfig()) with tempfile.TemporaryDirectory() as tmp_dir: - base_model.save_pretrained(tmp_dir) + base_model.save_pretrained(tmp_dir, safe_serialization=False) # Can load a base model in a model with head model = ModelWithHead.from_pretrained(tmp_dir) @@ -886,7 +896,7 @@ class ModelUtilsTest(TestCasePlus): head_state_dict = model.state_dict() base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"] base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"] - torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME)) + safe_save_file(base_state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) with self.assertRaisesRegex( ValueError, "The state dictionary of the model you are trying to load is corrupted." @@ -934,8 +944,8 @@ class ModelUtilsTest(TestCasePlus): # Loading the model with the same class, we do get a warning for unexpected weights state_dict = model.state_dict() - state_dict["added_key"] = state_dict["linear.weight"] - torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME)) + state_dict["added_key"] = copy.deepcopy(state_dict["linear.weight"]) + safe_save_file(state_dict, os.path.join(tmp_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) with CaptureLogger(logger) as cl: _, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True) self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out) @@ -1072,6 +1082,54 @@ class ModelUtilsTest(TestCasePlus): ) self.assertEqual(model.generation_config.transformers_version, "foo") + @require_safetensors + def test_safetensors_torch_from_torch(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + new_model = BertModel.from_pretrained(tmp_dir) + + for p1, p2 in zip(model.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + + @require_safetensors + @require_flax + def test_safetensors_torch_from_flax(self): + hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") + model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + new_model = BertModel.from_pretrained(tmp_dir) + + for p1, p2 in zip(hub_model.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + + @require_tf + @require_safetensors + def test_safetensors_torch_from_tf(self): + hub_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only") + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + new_model = BertModel.from_pretrained(tmp_dir) + + for p1, p2 in zip(hub_model.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + + @require_safetensors + def test_safetensors_torch_from_torch_sharded(self): + model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only") + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB") + new_model = BertModel.from_pretrained(tmp_dir) + + for p1, p2 in zip(model.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + @require_torch @is_staging_test diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 6c208d0de0..ae6d8f7ae3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -403,7 +403,7 @@ if is_torch_available(): class TrainerIntegrationCommon: - def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=False): + def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True): weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"] if is_pretrained: @@ -415,7 +415,7 @@ class TrainerIntegrationCommon: self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename))) def check_best_model_has_been_loaded( - self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=False + self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=True ): checkpoint = os.path.join(output_dir, f"checkpoint-{(total // freq) * freq}") log_history = TrainerState.load_from_json(os.path.join(checkpoint, "trainer_state.json")).log_history @@ -456,7 +456,7 @@ class TrainerIntegrationCommon: _ = log1.pop(key, None) self.assertEqual(log, log1) - def convert_to_sharded_checkpoint(self, folder, save_safe=False, load_safe=False): + def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True): # Converts a checkpoint of a regression model to a sharded checkpoint. if load_safe: loader = safetensors.torch.load_file diff --git a/tests/utils/test_cli.py b/tests/utils/test_cli.py index fc7b8ebb5e..b208ff19f1 100644 --- a/tests/utils/test_cli.py +++ b/tests/utils/test_cli.py @@ -43,7 +43,6 @@ class CLITest(unittest.TestCase): shutil.rmtree("/tmp/hf-internal-testing/tiny-random-gptj", ignore_errors=True) # cleans potential past runs transformers.commands.transformers_cli.main() - # The original repo has no TF weights -- if they exist, they were created by the CLI self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5")) @require_torch