Poc to use safetensors (#19175)
* Poc to use safetensors * Typo * Final version * Add tests * Save with the right name! * Update tests/test_modeling_common.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Support for sharded checkpoints * Test from Hub part 1 * Test from hub part 2 * Fix regular checkpoint sharding * Bump for fixes Co-authored-by: Julien Chaumond <julien@huggingface.co>
This commit is contained in:
@@ -53,6 +53,7 @@ from transformers.testing_utils import (
|
||||
is_pt_tf_cross_test,
|
||||
is_staging_test,
|
||||
require_accelerate,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
@@ -61,6 +62,8 @@ from transformers.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
@@ -2980,6 +2983,57 @@ class ModelUtilsTest(TestCasePlus):
|
||||
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
|
||||
)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_save_and_load(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
# No pytorch_model.bin file, only a model.safetensors
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||
|
||||
new_model = BertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub(self):
|
||||
safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
|
||||
pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_save_and_load_sharded(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
|
||||
# No pytorch_model.bin index file, only a model.safetensors index
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||
# No regular weights file
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
|
||||
new_model = BertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_sharded(self):
|
||||
safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded-safetensors")
|
||||
pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
|
||||
self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user