Default to msgpack for safetensors (#27460)
* Default to msgpack for safetensors
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -19,8 +19,16 @@ import numpy as np
|
||||
from huggingface_hub import HfFolder, delete_repo, snapshot_download
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import BertConfig, BertModel, is_flax_available
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch
|
||||
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
is_pt_flax_cross_test,
|
||||
is_staging_test,
|
||||
require_flax,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
)
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
|
||||
@@ -202,6 +210,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
@require_flax
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_save_and_load_pt_to_flax(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
@@ -218,21 +227,114 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub(self):
    """
    Check that a Hub checkpoint that only has safetensors weights can be loaded, by comparing it
    against the model loaded from the `tiny-bert-flax-only` Hub checkpoint.
    """
    reference_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

    # The safetensors weights in this checkpoint are Flax-formatted
    loaded_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")

    models_match = check_models_equal(reference_model, loaded_model)
    self.assertTrue(models_match)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local(self):
    """
    Check that safetensors weights can be loaded from a local snapshot directory, by comparing
    against the model loaded from a local snapshot of the `tiny-bert-flax-only` checkpoint.
    """

    def load_from_snapshot(repo_id):
        # Download into a throwaway cache and load from the resulting local path; the model is
        # built before the temporary directory is removed.
        with tempfile.TemporaryDirectory() as cache_dir:
            snapshot_path = snapshot_download(repo_id, cache_dir=cache_dir)
            return FlaxBertModel.from_pretrained(snapshot_path)

    reference_model = load_from_snapshot("hf-internal-testing/tiny-bert-flax-only")
    loaded_model = load_from_snapshot("hf-internal-testing/tiny-bert-flax-safetensors-only")

    self.assertTrue(check_models_equal(reference_model, loaded_model))
|
||||
|
||||
@require_torch
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
def test_safetensors_load_from_hub_from_safetensors_pt(self):
    """
    This test checks that we can load safetensors from a Hub checkpoint that only has those
    saved in the "pt" format.
    """
    flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")

    # Can load from the PyTorch-formatted checkpoint
    safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
    self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_torch
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_local_from_safetensors_pt(self):
    """
    Check that safetensors weights saved in the "pt" format can be loaded from a local snapshot
    directory, by comparing against the model loaded from a local `tiny-bert-msgpack` snapshot.
    """

    def load_from_snapshot(repo_id):
        # Download into a throwaway cache and load from the resulting local path; the model is
        # built before the temporary directory is removed.
        with tempfile.TemporaryDirectory() as cache_dir:
            snapshot_path = snapshot_download(repo_id, cache_dir=cache_dir)
            return FlaxBertModel.from_pretrained(snapshot_path)

    reference_model = load_from_snapshot("hf-internal-testing/tiny-bert-msgpack")

    # The second snapshot holds the PyTorch-formatted checkpoint
    loaded_model = load_from_snapshot("hf-internal-testing/tiny-bert-pt-safetensors")

    self.assertTrue(check_models_equal(reference_model, loaded_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
    """
    This test checks that we cannot load safetensors from a Hub checkpoint that only has safetensors
    saved in the "pt" format if torch isn't installed.
    """
    if is_torch_available():
        # The expected failure only happens when PyTorch is absent. Report the test as skipped
        # instead of returning early, which would count as a pass without checking anything.
        self.skipTest("PyTorch is installed; this test requires an environment without PyTorch.")

    # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
    with self.assertRaises(ModuleNotFoundError):
        _ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
    """
    This test checks that we cannot load safetensors from a local snapshot of a checkpoint that only
    has safetensors saved in the "pt" format if torch isn't installed.
    """
    if is_torch_available():
        # The expected failure only happens when PyTorch is absent. Report the test as skipped
        # instead of returning early, which would count as a pass without checking anything.
        self.skipTest("PyTorch is installed; this test requires an environment without PyTorch.")

    with tempfile.TemporaryDirectory() as tmp:
        location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)

        # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
        with self.assertRaises(ModuleNotFoundError):
            _ = FlaxBertModel.from_pretrained(location)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
    """
    Check that msgpack weights are picked up before safetensors when loading from the Hub.

    The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without
    PyTorch, so the load below is expected to go through the msgpack weights.
    """
    repo_id = "hf-internal-testing/tiny-bert-pt-safetensors-msgpack"
    FlaxBertModel.from_pretrained(repo_id)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local_msgpack_before_safetensors(self):
    """
    Check that msgpack weights are picked up before safetensors when loading from a local snapshot.

    The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without
    PyTorch, so the load below is expected to go through the msgpack weights.
    """
    repo_id = "hf-internal-testing/tiny-bert-pt-safetensors-msgpack"
    with tempfile.TemporaryDirectory() as cache_dir:
        snapshot_path = snapshot_download(repo_id, cache_dir=cache_dir)
        FlaxBertModel.from_pretrained(snapshot_path)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_flax_from_flax(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
||||
@@ -535,6 +535,71 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
# This should discard the safetensors weights in favor of the .h5 sharded weights
|
||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local(self):
    """
    This test checks that we can load safetensors from a local snapshot of a checkpoint that only
    has those, by comparing its weights against the `tiny-bert-tf-only` checkpoint.
    """
    with tempfile.TemporaryDirectory() as tmp:
        location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp)
        tf_model = TFBertModel.from_pretrained(location)

    with tempfile.TemporaryDirectory() as tmp:
        location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp)
        safetensors_model = TFBertModel.from_pretrained(location)

    # zip() stops at the shorter sequence, so make sure no weight is left uncompared
    self.assertEqual(len(tf_model.weights), len(safetensors_model.weights))
    for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
        self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt(self):
    """
    This test checks that we can load safetensors from a Hub checkpoint that only has those
    saved in the "pt" format.
    """
    tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5")

    # Can load from the PyTorch-formatted checkpoint
    safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")

    # zip() stops at the shorter sequence, so make sure no weight is left uncompared
    self.assertEqual(len(tf_model.weights), len(safetensors_model.weights))
    for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
        self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local_from_safetensors_pt(self):
    """
    This test checks that we can load safetensors from a local checkpoint that only has those
    saved in the "pt" format.
    """
    with tempfile.TemporaryDirectory() as tmp:
        location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp)
        tf_model = TFBertModel.from_pretrained(location)

    # Can load from the PyTorch-formatted checkpoint
    with tempfile.TemporaryDirectory() as tmp:
        location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
        safetensors_model = TFBertModel.from_pretrained(location)

    # zip() stops at the shorter sequence, so make sure no weight is left uncompared
    self.assertEqual(len(tf_model.weights), len(safetensors_model.weights))
    for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
        self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_h5_before_safetensors(self):
    """
    Check that h5 weights are picked up before safetensors when loading from the Hub.

    The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without
    PyTorch, so the load below is expected to go through the h5 weights.
    """
    # NOTE(review): the repo id says "msgpack" while this test targets h5 weights — confirm the
    # repo also ships an h5 file.
    repo_id = "hf-internal-testing/tiny-bert-pt-safetensors-msgpack"
    TFBertModel.from_pretrained(repo_id)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local_h5_before_safetensors(self):
    """
    Check that h5 weights are picked up before safetensors when loading from a local snapshot.

    The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without
    PyTorch, so the load below is expected to go through the h5 weights.
    """
    # NOTE(review): the repo id says "msgpack" while this test targets h5 weights — confirm the
    # repo also ships an h5 file.
    repo_id = "hf-internal-testing/tiny-bert-pt-safetensors-msgpack"
    with tempfile.TemporaryDirectory() as cache_dir:
        snapshot_path = snapshot_download(repo_id, cache_dir=cache_dir)
        TFBertModel.from_pretrained(snapshot_path)
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user