Support sharded safetensors in TF (#29350)
* Initial commit (still lots of unfinished bits)
* (Still untested) add safetensors sharding to save_pretrained
* Fix safetensors saving, update default shard size to match PT
* Add proper loading of TF-format safetensors
* Revert default size in case that changes things
* Fix incorrect index name
* Update loading priority
* Update tests
* Make the tests a little more stringent
* Expand tests
* Add sharded cross-test
* Fix argument name
* One more test fix
* Adding mlx to the list of allowed formats
* Remove irrelevant block for safetensors
* Refactor warning logging into a separate function
* Remove unused skip_logger_warnings arg
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Move function def

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -41,7 +41,13 @@ from transformers.testing_utils import ( # noqa: F401
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -340,6 +346,7 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
@require_safetensors
|
||||
def test_checkpoint_sharding_local(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
@@ -389,6 +396,45 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
def test_safetensors_checkpoint_sharding_local(self):
    """Check sharded safetensors saving and reloading of a tiny TF BERT model on local disk.

    The model is saved several times into the same directory, each time with a different
    `max_shard_size`. After every save the test checks that a safetensors index exists, that no
    single-file safetensors checkpoint and no `.h5` weights/index are present, and that the
    `.safetensors` shard files on disk are exactly the ones listed in the index. The checkpoint
    is then reloaded and its weights are compared with those of the original model.
    """
    model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
        for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
            model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=True)

            # Get each shard file and its size. The shards of this checkpoint end in `.safetensors`;
            # filtering on `.h5` would always leave this mapping empty.
            shard_to_size = {
                shard: os.path.getsize(os.path.join(tmp_dir, shard))
                for shard in os.listdir(tmp_dir)
                if shard.endswith(".safetensors")
            }

            index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
            # Check there is an index but no regular weight file
            self.assertTrue(os.path.isfile(index_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))

            # Check the index and the shard files found match
            with open(index_file, "r", encoding="utf-8") as f:
                index = json.load(f)

            all_shards = set(index["weight_map"].values())
            self.assertSetEqual(all_shards, set(shard_to_size))

            # A shard referenced by the index must not be an empty file
            for shard, size in shard_to_size.items():
                self.assertGreater(size, 0, f"{shard} is empty")

        # Finally, check the model can be reloaded (the folder holds the checkpoint of the last save)
        new_model = TFBertModel.from_pretrained(tmp_dir)

        model.build_in_name_scope()
        new_model.build_in_name_scope()

        for p1, p2 in zip(model.weights, new_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@slow
|
||||
def test_save_pretrained_signatures(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
@@ -437,7 +483,26 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
# No tf_model.h5 file, only a model.safetensors
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_sharded_save_and_load(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
|
||||
# No tf weights or index file, only a safetensors index
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
@@ -460,6 +525,21 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@is_pt_tf_cross_test
def test_sharded_safetensors_save_and_load_pt_to_tf(self):
    """Save a PyTorch BERT as sharded safetensors and load that checkpoint into the TF model.

    The weights of the TF model loaded from the sharded checkpoint are compared against a TF
    model loaded directly from the same hub repository.
    """
    model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    with tempfile.TemporaryDirectory() as tmp_dir:
        pt_model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")

        # Check we have a safetensors shard index file
        index_path = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
        self.assertTrue(os.path.isfile(index_path))

        new_model = TFBertModel.from_pretrained(tmp_dir)

    # Check models are equal
    for expected, reloaded in zip(model.weights, new_model.weights):
        weights_match = np.allclose(expected.numpy(), reloaded.numpy())
        self.assertTrue(weights_match)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub(self):
|
||||
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
@@ -512,9 +592,10 @@ class TFModelUtilsTest(unittest.TestCase):
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
|
||||
# This should not raise even if there are two types of sharded weights
|
||||
# This should discard the safetensors weights in favor of the .h5 sharded weights
|
||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
|
||||
# Confirm that we can correctly load the safetensors weights from a sharded hub repo even when TF weights present
|
||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=True)
|
||||
# Confirm that we can access the TF weights too
|
||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=False)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local(self):
|
||||
|
||||
Reference in New Issue
Block a user