Safetensors serialization by default (#27064)
* Safetensors serialization by default * First pass on the tests * Second pass on the tests * Third pass on the tests * Fix TF weight loading from TF-format safetensors * Specific encoder-decoder fixes for weight crossloading * Add VisionEncoderDecoder fixes for TF too * Change filename test for pt-to-tf * One missing fix for TFVisionEncoderDecoder * Fix the other crossload test * Support for flax + updated tests * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Sanchit's comments * Sanchit's comments 2 * Nico's comments * Fix tests * cleanup * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Matt <rocketknight1@gmail.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -24,7 +24,7 @@ import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download
|
||||
from huggingface_hub.file_download import http_get
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
@@ -39,6 +39,7 @@ from transformers.testing_utils import ( # noqa: F401
|
||||
is_staging_test,
|
||||
require_safetensors,
|
||||
require_tf,
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
||||
@@ -496,6 +497,44 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_tf_from_tf(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
@is_pt_tf_cross_test
|
||||
def test_safetensors_tf_from_torch(self):
|
||||
hub_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
for p1, p2 in zip(hub_model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_local(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
path = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", cache_dir=tmp_dir)
|
||||
|
||||
# This should not raise even if there are two types of sharded weights
|
||||
TFBertModel.from_pretrained(path)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
|
||||
# This should not raise even if there are two types of sharded weights
|
||||
# This should discard the safetensors weights in favor of the .h5 sharded weights
|
||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user