Safetensors serialization by default (#27064)

* Safetensors serialization by default

* First pass on the tests

* Second pass on the tests

* Third pass on the tests

* Fix TF weight loading from TF-format safetensors

* Specific encoder-decoder fixes for weight crossloading

* Add VisionEncoderDecoder fixes for TF too

* Change filename test for pt-to-tf

* One missing fix for TFVisionEncoderDecoder

* Fix the other crossload test

* Support for flax + updated tests

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Sanchit's comments

* Sanchit's comments 2

* Nico's comments

* Fix tests

* cleanup

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2023-10-31 19:16:49 +01:00
committed by GitHub
parent 25e6e9418c
commit 113ebf80ac
20 changed files with 433 additions and 137 deletions

View File

@@ -24,7 +24,7 @@ import tempfile
import unittest
import unittest.mock as mock
from huggingface_hub import HfFolder, Repository, delete_repo
from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError
@@ -39,6 +39,7 @@ from transformers.testing_utils import ( # noqa: F401
is_staging_test,
require_safetensors,
require_tf,
require_torch,
slow,
)
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
@@ -496,6 +497,44 @@ class TFModelUtilsTest(unittest.TestCase):
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_tf_from_tf(self):
    """Save a TF model with `safe_serialization=True` and check the reloaded weights match."""
    model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, safe_serialization=True)
        new_model = TFBertModel.from_pretrained(tmp_dir)

    # zip() stops at the shorter sequence, so check the counts first: otherwise a
    # reload that dropped weights would still pass the element-wise comparison.
    self.assertEqual(len(model.weights), len(new_model.weights))

    for p1, p2 in zip(model.weights, new_model.weights):
        self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
@is_pt_tf_cross_test
def test_safetensors_tf_from_torch(self):
    """Save a PyTorch model as safetensors, load it as a TF model, and compare with the TF hub weights."""
    hub_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
    model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, safe_serialization=True)
        new_model = TFBertModel.from_pretrained(tmp_dir)

    # zip() stops at the shorter sequence, so check the counts first: otherwise a
    # cross-load that dropped weights would still pass the element-wise comparison.
    self.assertEqual(len(hub_model.weights), len(new_model.weights))

    for p1, p2 in zip(hub_model.weights, new_model.weights):
        self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_local(self):
    """Load a TF model from a locally downloaded snapshot of the sharded h5 + safetensors test repo."""
    repo_id = "hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded"

    with tempfile.TemporaryDirectory() as cache_root:
        snapshot_dir = snapshot_download(repo_id, cache_dir=cache_root)

        # Two kinds of sharded weights sitting side by side must not make loading raise.
        TFBertModel.from_pretrained(snapshot_dir)
@require_safetensors
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
    """Load a TF model straight from the hub repo of the sharded h5 + safetensors test checkpoint."""
    repo_id = "hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded"

    # Loading must not raise even though two kinds of sharded weights are present;
    # the safetensors weights are expected to be discarded in favor of the sharded .h5 ones.
    TFBertModel.from_pretrained(repo_id)
@require_tf
@is_staging_test