From 9dc8fe1b325f270320cdf205778bddae03c6ba1f Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 13 Nov 2023 15:17:01 +0100 Subject: [PATCH] Default to msgpack for safetensors (#27460) * Default to msgpack for safetensors * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../modeling_flax_pytorch_utils.py | 6 +- src/transformers/modeling_flax_utils.py | 41 +++---- tests/test_modeling_flax_utils.py | 112 +++++++++++++++++- tests/test_modeling_tf_utils.py | 65 ++++++++++ 4 files changed, 196 insertions(+), 28 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 5a0f52a995..f78c4e78c7 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -50,7 +50,7 @@ def load_pytorch_checkpoint_in_flax_state_dict( """Load pytorch checkpoints in a flax model""" try: import torch # noqa: F401 - except ImportError: + except (ImportError, ModuleNotFoundError): logger.error( "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" @@ -150,7 +150,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision try: import torch # noqa: F401 - except ImportError: + except (ImportError, ModuleNotFoundError): logger.error( "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see" " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" @@ -349,7 +349,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): try: import torch # noqa: F401 - except ImportError: + except (ImportError, ModuleNotFoundError): logger.error( "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see" " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 9e63cb0cb9..37567d3d84 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -721,7 +721,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) if os.path.isdir(pretrained_model_name_or_path): - if is_safetensors_available() and os.path.isfile( + if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)): + # Load from a sharded Flax checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME) + is_sharded = True + elif is_safetensors_available() and os.path.isfile( os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) ): # Load from a safetensors checkpoint @@ -735,13 +742,6 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): # Load from a sharded pytorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) is_sharded = True - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): - # Load from a Flax checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) - elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)): - # Load from a sharded Flax checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME) - is_sharded = True # At this stage we don't have a weight file so we will raise an error. elif is_safetensors_available() and os.path.isfile( os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) @@ -770,8 +770,6 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): else: if from_pt: filename = WEIGHTS_NAME - elif is_safetensors_available(): - filename = SAFE_WEIGHTS_NAME else: filename = FLAX_WEIGHTS_NAME @@ -792,22 +790,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): } resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None - # result when internet is up, the repo and revision exist, but the file does not. - if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME: - # Did not find the safetensors file, let's fallback to Flax. - # No support for sharded safetensors yet, so we'll raise an error if that's all we find. - filename = FLAX_WEIGHTS_NAME - resolved_archive_file = cached_file( - pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **cached_file_kwargs - ) + # Maybe the checkpoint is sharded, we try to grab the index name in this case. if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME: - # Maybe the checkpoint is sharded, we try to grab the index name in this case. resolved_archive_file = cached_file( pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs ) if resolved_archive_file is not None: is_sharded = True + # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. if resolved_archive_file is None and from_pt: resolved_archive_file = cached_file( @@ -815,6 +805,17 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ) if resolved_archive_file is not None: is_sharded = True + + # If we still haven't found anything, look for `safetensors`. + if resolved_archive_file is None: + # No support for sharded safetensors yet, so we'll raise an error if that's all we find. + filename = SAFE_WEIGHTS_NAME + resolved_archive_file = cached_file( + pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs + ) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. if resolved_archive_file is None: # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error # message. diff --git a/tests/test_modeling_flax_utils.py b/tests/test_modeling_flax_utils.py index e0e6c873c6..e668b43532 100644 --- a/tests/test_modeling_flax_utils.py +++ b/tests/test_modeling_flax_utils.py @@ -19,8 +19,16 @@ import numpy as np from huggingface_hub import HfFolder, delete_repo, snapshot_download from requests.exceptions import HTTPError -from transformers import BertConfig, BertModel, is_flax_available -from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch +from transformers import BertConfig, BertModel, is_flax_available, is_torch_available +from transformers.testing_utils import ( + TOKEN, + USER, + is_pt_flax_cross_test, + is_staging_test, + require_flax, + require_safetensors, + require_torch, +) from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME @@ -202,6 +210,7 @@ class FlaxModelUtilsTest(unittest.TestCase): @require_flax @require_torch + @is_pt_flax_cross_test def test_safetensors_save_and_load_pt_to_flax(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True) pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") @@ -218,21 +227,114 @@ class FlaxModelUtilsTest(unittest.TestCase): @require_safetensors def test_safetensors_load_from_hub(self): + """ + This test checks that we can load safetensors from a checkpoint that only has those on the Hub + """ flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") # Can load from the Flax-formatted checkpoint safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only") self.assertTrue(check_models_equal(flax_model, safetensors_model)) + @require_safetensors + def test_safetensors_load_from_local(self): + """ + This test checks that we can load safetensors from a checkpoint that only has those on the Hub + """ + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-flax-only", cache_dir=tmp) + flax_model = FlaxBertModel.from_pretrained(location) + + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-flax-safetensors-only", cache_dir=tmp) + safetensors_model = FlaxBertModel.from_pretrained(location) + + self.assertTrue(check_models_equal(flax_model, safetensors_model)) + @require_torch @require_safetensors - def test_safetensors_load_from_hub_flax_and_pt(self): - flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + @is_pt_flax_cross_test + def test_safetensors_load_from_hub_from_safetensors_pt(self): + """ + This test checks that we can load safetensors from a checkpoint that only has those on the Hub. + saved in the "pt" format. + """ + flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack") # Can load from the PyTorch-formatted checkpoint - safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only", from_pt=True) + safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors") self.assertTrue(check_models_equal(flax_model, safetensors_model)) + @require_torch + @require_safetensors + @is_pt_flax_cross_test + def test_safetensors_load_from_local_from_safetensors_pt(self): + """ + This test checks that we can load safetensors from a checkpoint that only has those on the Hub. + saved in the "pt" format. + """ + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp) + flax_model = FlaxBertModel.from_pretrained(location) + + # Can load from the PyTorch-formatted checkpoint + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp) + safetensors_model = FlaxBertModel.from_pretrained(location) + + self.assertTrue(check_models_equal(flax_model, safetensors_model)) + + @require_safetensors + def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self): + """ + This test checks that we cannot load safetensors from a checkpoint that only has safetensors + saved in the "pt" format if torch isn't installed. + """ + if is_torch_available(): + # This test verifies that a correct error message is shown when loading from a pt safetensors + # PyTorch shouldn't be installed for this to work correctly. + return + + # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed + with self.assertRaises(ModuleNotFoundError): + _ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors") + + @require_safetensors + def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self): + """ + This test checks that we cannot load safetensors from a checkpoint that only has safetensors + saved in the "pt" format if torch isn't installed. + """ + if is_torch_available(): + # This test verifies that a correct error message is shown when loading from a pt safetensors + # PyTorch shouldn't be installed for this to work correctly. + return + + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp) + + # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed + with self.assertRaises(ModuleNotFoundError): + _ = FlaxBertModel.from_pretrained(location) + + @require_safetensors + def test_safetensors_load_from_hub_msgpack_before_safetensors(self): + """ + This test checks that we'll first download msgpack weights before safetensors + The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch + """ + FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack") + + @require_safetensors + def test_safetensors_load_from_local_msgpack_before_safetensors(self): + """ + This test checks that we'll first download msgpack weights before safetensors + The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch + """ + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp) + FlaxBertModel.from_pretrained(location) + @require_safetensors def test_safetensors_flax_from_flax(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") diff --git a/tests/test_modeling_tf_utils.py b/tests/test_modeling_tf_utils.py index 130f920f71..ccc3f1f5ce 100644 --- a/tests/test_modeling_tf_utils.py +++ b/tests/test_modeling_tf_utils.py @@ -535,6 +535,71 @@ class TFModelUtilsTest(unittest.TestCase): # This should discard the safetensors weights in favor of the .h5 sharded weights TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded") + @require_safetensors + def test_safetensors_load_from_local(self): + """ + This test checks that we can load safetensors from a checkpoint that only has those on the Hub + """ + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp) + tf_model = TFBertModel.from_pretrained(location) + + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp) + safetensors_model = TFBertModel.from_pretrained(location) + + for p1, p2 in zip(tf_model.weights, safetensors_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + + @require_safetensors + def test_safetensors_load_from_hub_from_safetensors_pt(self): + """ + This test checks that we can load safetensors from a checkpoint that only has those on the Hub. + saved in the "pt" format. + """ + tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5") + + # Can load from the PyTorch-formatted checkpoint + safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors") + for p1, p2 in zip(tf_model.weights, safetensors_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + + @require_safetensors + def test_safetensors_load_from_local_from_safetensors_pt(self): + """ + This test checks that we can load safetensors from a local checkpoint that only has those + saved in the "pt" format. + """ + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp) + tf_model = TFBertModel.from_pretrained(location) + + # Can load from the PyTorch-formatted checkpoint + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp) + safetensors_model = TFBertModel.from_pretrained(location) + + for p1, p2 in zip(tf_model.weights, safetensors_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + + @require_safetensors + def test_safetensors_load_from_hub_h5_before_safetensors(self): + """ + This test checks that we'll first download h5 weights before safetensors + The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch + """ + TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack") + + @require_safetensors + def test_safetensors_load_from_local_h5_before_safetensors(self): + """ + This test checks that we'll first download h5 weights before safetensors + The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch + """ + with tempfile.TemporaryDirectory() as tmp: + location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp) + TFBertModel.from_pretrained(location) + @require_tf @is_staging_test