Default to msgpack for safetensors (#27460)
* Default to msgpack for safetensors * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -535,6 +535,71 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
# This should discard the safetensors weights in favor of the .h5 sharded weights
|
||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
|
||||
|
||||
@require_safetensors
def test_safetensors_load_from_local(self):
    """
    This test checks that we can load safetensors from a local checkpoint that only has those.

    Two local snapshots are loaded (`tiny-bert-tf-only` and `tiny-bert-tf-safetensors-only`) and the
    weights of the resulting models are compared pairwise.
    """
    with tempfile.TemporaryDirectory() as tmp:
        location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp)
        tf_model = TFBertModel.from_pretrained(location)

    with tempfile.TemporaryDirectory() as tmp:
        location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp)
        safetensors_model = TFBertModel.from_pretrained(location)

    # zip() stops at the shorter sequence, so guard against weights being silently left uncompared.
    self.assertEqual(len(tf_model.weights), len(safetensors_model.weights))
    for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
        self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
def test_safetensors_load_from_hub_from_safetensors_pt(self):
    """
    This test checks that we can load safetensors from a checkpoint on the Hub that only has those
    saved in the "pt" format.

    The weights are compared pairwise against a model loaded from `tiny-bert-h5`.
    """
    tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5")

    # Can load from the PyTorch-formatted checkpoint
    safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")

    # zip() stops at the shorter sequence, so guard against weights being silently left uncompared.
    self.assertEqual(len(tf_model.weights), len(safetensors_model.weights))
    for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
        self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
def test_safetensors_load_from_local_from_safetensors_pt(self):
    """
    This test checks that we can load safetensors from a local checkpoint that only has those
    saved in the "pt" format.

    The weights are compared pairwise against a model loaded from a local snapshot of `tiny-bert-h5`.
    """
    with tempfile.TemporaryDirectory() as tmp:
        location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp)
        tf_model = TFBertModel.from_pretrained(location)

    # Can load from the PyTorch-formatted checkpoint
    with tempfile.TemporaryDirectory() as tmp:
        location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
        safetensors_model = TFBertModel.from_pretrained(location)

    # zip() stops at the shorter sequence, so guard against weights being silently left uncompared.
    self.assertEqual(len(tf_model.weights), len(safetensors_model.weights))
    for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
        self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
def test_safetensors_load_from_hub_h5_before_safetensors(self):
    """
    Loading from the Hub must succeed for a repo whose safetensors file is in the PyTorch format.

    Stated intent: the h5 weights are fetched in preference to the safetensors file, since the latter
    is a pt safetensors and therefore cannot be loaded without PyTorch.
    NOTE(review): the repo name mentions msgpack rather than h5 - confirm which weight files it ships.
    """
    repo_id = "hf-internal-testing/tiny-bert-pt-safetensors-msgpack"
    TFBertModel.from_pretrained(repo_id)
|
||||
|
||||
@require_safetensors
def test_safetensors_load_from_local_h5_before_safetensors(self):
    """
    Loading from a local snapshot must succeed for a repo whose safetensors file is in the PyTorch format.

    Stated intent: the h5 weights are picked in preference to the safetensors file, since the latter
    is a pt safetensors and therefore cannot be loaded without PyTorch.
    NOTE(review): the repo name mentions msgpack rather than h5 - confirm which weight files it ships.
    """
    repo_id = "hf-internal-testing/tiny-bert-pt-safetensors-msgpack"
    with tempfile.TemporaryDirectory() as cache_root:
        snapshot_dir = snapshot_download(repo_id, cache_dir=cache_root)
        # Load while the temporary directory still exists; only a successful load is checked.
        TFBertModel.from_pretrained(snapshot_dir)
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user