Support sharded safetensors in TF (#29350)

* Initial commit (still lots of unfinished bits)

* (Still untested) add safetensors sharding to save_pretrained

* Fix safetensors saving, update default shard size to match PT

* Add proper loading of TF-format safetensors

* Revert default size in case that changes things

* Fix incorrect index name

* Update loading priority

* Update tests

* Make the tests a little more stringent

* Expand tests

* Add sharded cross-test

* Fix argument name

* One more test fix

* Adding mlx to the list of allowed formats

* Remove irrelevant block for safetensors

* Refactor warning logging into a separate function

* Remove unused skip_logger_warnings arg

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Move function def

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Matt
2024-03-20 14:22:35 +00:00
committed by GitHub
parent 870bbb4c6b
commit 11ef35e828
3 changed files with 324 additions and 86 deletions

View File

@@ -41,7 +41,13 @@ from transformers.testing_utils import ( # noqa: F401
require_torch,
slow,
)
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
logging,
)
logger = logging.get_logger(__name__)
@@ -340,6 +346,7 @@ class TFModelUtilsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
@require_safetensors
def test_checkpoint_sharding_local(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -389,6 +396,45 @@ class TFModelUtilsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_checkpoint_sharding_local(self):
    """Save a sharded safetensors checkpoint locally and reload it.

    For several `max_shard_size` values, written to the same directory one
    after another, the test checks that:

    * a safetensors index file exists and no single-file safetensors
      checkpoint, `.h5` checkpoint or `.h5` index is present,
    * the shard files named in the index are exactly the `.safetensors`
      files in the directory (so shards from a previous save are gone),
    * the checkpoint reloads into a model whose weights match the original.
    """
    model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    with tempfile.TemporaryDirectory() as tmp_dir:
        # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
        for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
            model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=True)

            index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
            # Check there is an index but no regular weight file
            self.assertTrue(os.path.isfile(index_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))

            # Check the index and the shard files found match
            with open(index_file, "r", encoding="utf-8") as f:
                index = json.load(f)
            all_shards = set(index["weight_map"].values())
            shards_found = {name for name in os.listdir(tmp_dir) if name.endswith(".safetensors")}
            self.assertSetEqual(all_shards, shards_found)

            # Finally, check the model can be reloaded
            new_model = TFBertModel.from_pretrained(tmp_dir)
            model.build_in_name_scope()
            new_model.build_in_name_scope()
            # zip() stops at the shorter sequence, so compare lengths first to
            # keep the element-wise check from passing on a partial reload.
            self.assertEqual(len(model.weights), len(new_model.weights))
            for p1, p2 in zip(model.weights, new_model.weights):
                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@slow
def test_save_pretrained_signatures(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -437,7 +483,26 @@ class TFModelUtilsTest(unittest.TestCase):
model.save_pretrained(tmp_dir, safe_serialization=True)
# No tf_model.h5 file, only a model.safetensors
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
new_model = TFBertModel.from_pretrained(tmp_dir)
# Check models are equal
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_sharded_save_and_load(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
# No tf weights or index file, only a safetensors index
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
new_model = TFBertModel.from_pretrained(tmp_dir)
@@ -460,6 +525,21 @@ class TFModelUtilsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@is_pt_tf_cross_test
def test_sharded_safetensors_save_and_load_pt_to_tf(self):
    """Load a sharded safetensors checkpoint written by `BertModel` into `TFBertModel`.

    The `BertModel` is saved with `safe_serialization=True` and a small
    `max_shard_size`; the test asserts that a safetensors index file was
    written, reloads the directory with `TFBertModel`, and compares the
    reloaded weights against a `TFBertModel` loaded from the same hub
    checkpoint.
    """
    tf_reference = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    torch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    with tempfile.TemporaryDirectory() as save_dir:
        torch_model.save_pretrained(save_dir, safe_serialization=True, max_shard_size="150kB")

        # Check we have a safetensors shard index file
        index_path = os.path.join(save_dir, SAFE_WEIGHTS_INDEX_NAME)
        self.assertTrue(os.path.isfile(index_path))

        tf_reloaded = TFBertModel.from_pretrained(save_dir)

        # Check models are equal
        for expected, actual in zip(tf_reference.weights, tf_reloaded.weights):
            values_match = np.allclose(expected.numpy(), actual.numpy())
            self.assertTrue(values_match)
@require_safetensors
def test_safetensors_load_from_hub(self):
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -512,9 +592,10 @@ class TFModelUtilsTest(unittest.TestCase):
@require_safetensors
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the .h5 sharded weights
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
# Confirm that we can correctly load the safetensors weights from a sharded hub repo even when TF weights present
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=True)
# Confirm that we can access the TF weights too
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=False)
@require_safetensors
def test_safetensors_load_from_local(self):