Default to msgpack for safetensors (#27460)

* Default to msgpack for safetensors

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2023-11-13 15:17:01 +01:00
committed by Lysandre
parent f1185a4a73
commit 757171dfcf
4 changed files with 196 additions and 28 deletions

View File

@@ -19,8 +19,16 @@ import numpy as np
from huggingface_hub import HfFolder, delete_repo, snapshot_download
from requests.exceptions import HTTPError
from transformers import BertConfig, BertModel, is_flax_available
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
from transformers.testing_utils import (
TOKEN,
USER,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
require_safetensors,
require_torch,
)
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
@@ -202,6 +210,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
@require_flax
@require_torch
@is_pt_flax_cross_test
def test_safetensors_save_and_load_pt_to_flax(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -218,21 +227,114 @@ class FlaxModelUtilsTest(unittest.TestCase):
@require_safetensors
def test_safetensors_load_from_hub(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
"""
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
# Can load from the Flax-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
def test_safetensors_load_from_local(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-flax-only", cache_dir=tmp)
flax_model = FlaxBertModel.from_pretrained(location)
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-flax-safetensors-only", cache_dir=tmp)
safetensors_model = FlaxBertModel.from_pretrained(location)
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_torch
@require_safetensors
def test_safetensors_load_from_hub_flax_and_pt(self):
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")
# Can load from the PyTorch-formatted checkpoint
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only", from_pt=True)
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_torch
@require_safetensors
@is_pt_flax_cross_test
def test_safetensors_load_from_local_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
flax_model = FlaxBertModel.from_pretrained(location)
# Can load from the PyTorch-formatted checkpoint
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
safetensors_model = FlaxBertModel.from_pretrained(location)
self.assertTrue(check_models_equal(flax_model, safetensors_model))
@require_safetensors
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
@require_safetensors
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
"""
This test checks that we cannot load safetensors from a checkpoint that only has safetensors
saved in the "pt" format if torch isn't installed.
"""
if is_torch_available():
# This test verifies that a correct error message is shown when loading from a pt safetensors
# PyTorch shouldn't be installed for this to work correctly.
return
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
# Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
with self.assertRaises(ModuleNotFoundError):
_ = FlaxBertModel.from_pretrained(location)
@require_safetensors
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
"""
This test checks that we'll first download msgpack weights before safetensors
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
"""
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
@require_safetensors
def test_safetensors_load_from_local_msgpack_before_safetensors(self):
"""
This test checks that we'll first download msgpack weights before safetensors
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
FlaxBertModel.from_pretrained(location)
@require_safetensors
def test_safetensors_flax_from_flax(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

View File

@@ -535,6 +535,71 @@ class TFModelUtilsTest(unittest.TestCase):
# This should discard the safetensors weights in favor of the .h5 sharded weights
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
@require_safetensors
def test_safetensors_load_from_local(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp)
tf_model = TFBertModel.from_pretrained(location)
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp)
safetensors_model = TFBertModel.from_pretrained(location)
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_load_from_hub_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
saved in the "pt" format.
"""
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5")
# Can load from the PyTorch-formatted checkpoint
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_load_from_local_from_safetensors_pt(self):
"""
This test checks that we can load safetensors from a local checkpoint that only has those
saved in the "pt" format.
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp)
tf_model = TFBertModel.from_pretrained(location)
# Can load from the PyTorch-formatted checkpoint
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
safetensors_model = TFBertModel.from_pretrained(location)
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_load_from_hub_h5_before_safetensors(self):
"""
This test checks that we'll first download h5 weights before safetensors
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
"""
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
@require_safetensors
def test_safetensors_load_from_local_h5_before_safetensors(self):
"""
This test checks that we'll first download h5 weights before safetensors
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
"""
with tempfile.TemporaryDirectory() as tmp:
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
TFBertModel.from_pretrained(location)
@require_tf
@is_staging_test