TF Sharded (#17713)
* initial commit * update modeeling tf utils * quality * clean and update args * update * remove potential bug * code quality * update * update max shard * update tests for sharding from pretrained * fix remaining test * make style * h5py if tf available * update and fix test * fix test * style * modified push to hub to support shard for TF * quick fix * update code * merge branch main and style * Apply suggestions from code review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * update based on reviews * update doc * update and style * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update based on reviews * fix typo * style Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -53,6 +53,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import h5py
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -85,7 +86,12 @@ if is_tf_available():
|
||||
TFSampleDecoderOnlyOutput,
|
||||
TFSampleEncoderDecoderOutput,
|
||||
)
|
||||
from transformers.modeling_tf_utils import unpack_inputs
|
||||
from transformers.modeling_tf_utils import (
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
tf_shard_checkpoint,
|
||||
unpack_inputs,
|
||||
)
|
||||
from transformers.tf_utils import stable_softmax
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
@@ -1867,6 +1873,129 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
out = masked_softmax(x, boolean_mask)
|
||||
assert tf.experimental.numpy.allclose(xla_out, out)
|
||||
|
||||
def test_checkpoint_sharding_from_hub(self):
|
||||
model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
# the model above is the same as the model below, just a sharded version.
|
||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
def test_shard_checkpoint(self):
|
||||
# This is the model we will use, total size 340,000 bytes.
|
||||
model = tf.keras.Sequential(
|
||||
[
|
||||
tf.keras.layers.Dense(200, use_bias=False), # size 80,000
|
||||
tf.keras.layers.Dense(200, use_bias=False), # size 160,000
|
||||
tf.keras.layers.Dense(100, use_bias=False), # size 80,000
|
||||
tf.keras.layers.Dense(50, use_bias=False), # size 20,000
|
||||
]
|
||||
)
|
||||
inputs = tf.zeros((1, 100), dtype=tf.float32)
|
||||
model(inputs)
|
||||
weights = model.weights
|
||||
weights_dict = {w.name: w for w in weights}
|
||||
with self.subTest("No shard when max size is bigger than model size"):
|
||||
shards, index = tf_shard_checkpoint(weights)
|
||||
self.assertIsNone(index)
|
||||
self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
|
||||
|
||||
with self.subTest("Test sharding, no weights bigger than max size"):
|
||||
shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
|
||||
# Split is first two layers then last two.
|
||||
self.assertDictEqual(
|
||||
index,
|
||||
{
|
||||
"metadata": {"total_size": 340000},
|
||||
"weight_map": {
|
||||
"dense/kernel:0": "tf_model-00001-of-00002.h5",
|
||||
"dense_1/kernel:0": "tf_model-00001-of-00002.h5",
|
||||
"dense_2/kernel:0": "tf_model-00002-of-00002.h5",
|
||||
"dense_3/kernel:0": "tf_model-00002-of-00002.h5",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
|
||||
shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||
self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
|
||||
|
||||
with self.subTest("Test sharding with weights bigger than max size"):
|
||||
shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
|
||||
# Split is first layer, second layer then last 2.
|
||||
self.assertDictEqual(
|
||||
index,
|
||||
{
|
||||
"metadata": {"total_size": 340000},
|
||||
"weight_map": {
|
||||
"dense/kernel:0": "tf_model-00001-of-00003.h5",
|
||||
"dense_1/kernel:0": "tf_model-00002-of-00003.h5",
|
||||
"dense_2/kernel:0": "tf_model-00003-of-00003.h5",
|
||||
"dense_3/kernel:0": "tf_model-00003-of-00003.h5",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
shard1 = [weights_dict["dense/kernel:0"]]
|
||||
shard2 = [weights_dict["dense_1/kernel:0"]]
|
||||
shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||
self.assertDictEqual(
|
||||
shards,
|
||||
{
|
||||
"tf_model-00001-of-00003.h5": shard1,
|
||||
"tf_model-00002-of-00003.h5": shard2,
|
||||
"tf_model-00003-of-00003.h5": shard3,
|
||||
},
|
||||
)
|
||||
|
||||
def test_checkpoint_sharding_local(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||
|
||||
# Get each shard file and its size
|
||||
shard_to_size = {}
|
||||
for shard in os.listdir(tmp_dir):
|
||||
if shard.endswith(".h5"):
|
||||
shard_file = os.path.join(tmp_dir, shard)
|
||||
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||
|
||||
index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
|
||||
# Check there is an index but no regular weight file
|
||||
self.assertTrue(os.path.isfile(index_file))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
|
||||
# Check a file is bigger than max_size only when it has a single weight
|
||||
for shard_file, size in shard_to_size.items():
|
||||
if max_size.endswith("kiB"):
|
||||
max_size_int = int(max_size[:-3]) * 2**10
|
||||
else:
|
||||
max_size_int = int(max_size[:-2]) * 10**3
|
||||
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||
# the size asked for (since we count parameters)
|
||||
if size >= max_size_int + 50000:
|
||||
with h5py.File(shard_file, "r") as state_file:
|
||||
self.assertEqual(len(state_file), 1)
|
||||
|
||||
# Check the index and the shard files found match
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".h5"))
|
||||
self.assertSetEqual(all_shards, shards_found)
|
||||
|
||||
# Finally, check the model can be reloaded
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
model(model.dummy_inputs)
|
||||
new_model(model.dummy_inputs)
|
||||
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user