Safetensors serialization by default (#27064)
* Safetensors serialization by default
* First pass on the tests
* Second pass on the tests
* Third pass on the tests
* Fix TF weight loading from TF-format safetensors
* Specific encoder-decoder fixes for weight crossloading
* Add VisionEncoderDecoder fixes for TF too
* Change filename test for pt-to-tf
* One missing fix for TFVisionEncoderDecoder
* Fix the other crossload test
* Support for flax + updated tests
* Apply suggestions from code review
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Sanchit's comments
* Sanchit's comments 2
* Nico's comments
* Fix tests
* cleanup
* Apply suggestions from code review
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---------
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -16,11 +16,12 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import HfFolder, delete_repo
|
||||
from huggingface_hub import HfFolder, delete_repo, snapshot_download
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import BertConfig, is_flax_available
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax
|
||||
from transformers import BertConfig, BertModel, is_flax_available
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax, require_safetensors, require_torch
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -184,3 +185,88 @@ class FlaxModelUtilsTest(unittest.TestCase):
|
||||
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_safetensors
def test_safetensors_save_and_load(self):
    """Round-trip a Flax model through a safetensors checkpoint on disk.

    The model is saved with ``safe_serialization=True``; the test then checks
    which weight files exist in the output directory, reloads the model from
    that directory and compares it with the original.
    """
    original = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir, safe_serialization=True)

        safetensors_path = os.path.join(save_dir, SAFE_WEIGHTS_NAME)
        flax_weights_path = os.path.join(save_dir, FLAX_WEIGHTS_NAME)

        # Only model.safetensors may be present; no msgpack weights file
        self.assertTrue(os.path.isfile(safetensors_path))
        self.assertFalse(os.path.isfile(flax_weights_path))

        reloaded = FlaxBertModel.from_pretrained(save_dir)

        self.assertTrue(check_models_equal(original, reloaded))
|
||||
|
||||
@require_flax
@require_torch
@require_safetensors
def test_safetensors_save_and_load_pt_to_flax(self):
    """Check that a PyTorch model saved to disk can be loaded back as a Flax model.

    A reference Flax model is built from the PyTorch hub checkpoint with
    ``from_pt=True``. The PyTorch model is then saved with the default
    ``save_pretrained`` arguments, the test asserts that a safetensors weights
    file was produced, and the directory is loaded with ``FlaxBertModel`` and
    compared with the reference.

    The test is gated on ``require_safetensors`` because it asserts the
    presence of a safetensors file and loads from it, matching the decorators
    used on ``test_safetensors_flax_from_torch``.
    """
    model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
    pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Deliberately no safe_serialization argument: the default is under test
        pt_model.save_pretrained(tmp_dir)

        # Check we have a model.safetensors file
        self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))

        new_model = FlaxBertModel.from_pretrained(tmp_dir)

        # Check models are equal
        self.assertTrue(check_models_equal(model, new_model))
|
||||
|
||||
@require_safetensors
def test_safetensors_load_from_hub(self):
    """Compare a model loaded from the safetensors-only hub repo with the Flax-only one."""
    reference = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

    # Can load from the Flax-formatted checkpoint
    from_safetensors = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")

    models_match = check_models_equal(reference, from_safetensors)
    self.assertTrue(models_match)
|
||||
|
||||
@require_torch
@require_safetensors
def test_safetensors_load_from_hub_flax_and_pt(self):
    """Compare a Flax model loaded from the PyTorch-only hub repo with the Flax-only one."""
    reference = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

    # Can load from the PyTorch-formatted checkpoint
    converted_from_pt = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only", from_pt=True)

    models_match = check_models_equal(reference, converted_from_pt)
    self.assertTrue(models_match)
|
||||
|
||||
@require_safetensors
def test_safetensors_flax_from_flax(self):
    """Save a Flax model with ``safe_serialization=True`` and reload it as Flax."""
    original = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

    with tempfile.TemporaryDirectory() as save_dir:
        original.save_pretrained(save_dir, safe_serialization=True)
        # Reload while the temporary directory still exists
        reloaded = FlaxBertModel.from_pretrained(save_dir)

    models_match = check_models_equal(original, reloaded)
    self.assertTrue(models_match)
|
||||
|
||||
@require_safetensors
@require_torch
def test_safetensors_flax_from_torch(self):
    """Save a PyTorch model with ``safe_serialization=True`` and load it as Flax.

    The result is compared with the Flax model loaded from the Flax-only hub repo.
    """
    reference = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
    torch_model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")

    with tempfile.TemporaryDirectory() as save_dir:
        torch_model.save_pretrained(save_dir, safe_serialization=True)
        # Load the PyTorch-saved directory with the Flax class
        crossloaded = FlaxBertModel.from_pretrained(save_dir)

    models_match = check_models_equal(reference, crossloaded)
    self.assertTrue(models_match)
|
||||
|
||||
@require_safetensors
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self):
    """Load from a local snapshot of a repo holding two kinds of sharded weights.

    The repo is downloaded into a temporary cache directory and the model is
    loaded from the returned snapshot path; the test passes if loading does
    not raise.
    """
    with tempfile.TemporaryDirectory() as cache_root:
        snapshot_dir = snapshot_download(
            "hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded",
            cache_dir=cache_root,
        )

        # This should not raise even if there are two types of sharded weights
        FlaxBertModel.from_pretrained(snapshot_dir)
|
||||
|
||||
@require_safetensors
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self):
    """Load directly from a hub repo holding two kinds of sharded weights.

    The test passes if loading does not raise.
    """
    repo_id = "hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded"

    # This should not raise even if there are two types of sharded weights
    # This should discard the safetensors weights in favor of the msgpack sharded weights
    FlaxBertModel.from_pretrained(repo_id)
|
||||
|
||||
Reference in New Issue
Block a user