Checkpoint sharding (#16343)
* Sharded checkpoint support
* Handle distant sharded checkpoints
* Add tests
* TODO is done
* Apply suggestions from code review
  Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
* Fix docstring
* Add example and format
* Address review comments
* More review comments
* End of merge
* Revert unintentional change
* VsCode what did you do?
* Style
* Changes
* Address final comments
* Quality
* Moar tests
* Move import beneath is_pt_available

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -55,7 +55,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
|
||||
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_flax_available, is_torch_fx_available
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
@@ -90,6 +90,7 @@ if is_torch_available():
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers.modeling_utils import shard_checkpoint
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
@@ -2352,6 +2353,123 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_shard_checkpoint(self):
    """Check how `shard_checkpoint` splits a small state dict for different `max_shard_size` values.

    Three scenarios are exercised: the default limit (a single shard and no index), a limit that lets
    two layers share a shard, and a limit smaller than the largest single weight.
    """
    # (in_features, out_features) of each bias-free layer. The sizes expected below are
    # 80,000 + 160,000 + 80,000 + 20,000 = 340,000 bytes in total.
    layer_dims = [(100, 200), (200, 200), (200, 100), (100, 50)]
    layers = [torch.nn.Linear(n_in, n_out, bias=False) for n_in, n_out in layer_dims]
    model = torch.nn.Sequential(*layers)
    weights = model.state_dict()

    def pick(*names):
        # Sub-dict of the state dict restricted to the given weight names.
        return {name: weights[name] for name in names}

    with self.subTest("No shard when max size is bigger than model size"):
        shards, index = shard_checkpoint(weights)
        self.assertIsNone(index)
        self.assertDictEqual(shards, {WEIGHTS_NAME: weights})

    with self.subTest("Test sharding, no weights bigger than max size"):
        shards, index = shard_checkpoint(weights, max_shard_size="300kB")
        first = "pytorch_model-00001-of-00002.bin"
        second = "pytorch_model-00002-of-00002.bin"

        # Split is first two layers then last two.
        expected_index = {
            "metadata": {"total_size": 340000},
            "weight_map": {"0.weight": first, "1.weight": first, "2.weight": second, "3.weight": second},
        }
        self.assertDictEqual(index, expected_index)

        expected_shards = {first: pick("0.weight", "1.weight"), second: pick("2.weight", "3.weight")}
        self.assertDictEqual(shards, expected_shards)

    with self.subTest("Test sharding with weights bigger than max size"):
        shards, index = shard_checkpoint(weights, max_shard_size="100kB")
        first = "pytorch_model-00001-of-00003.bin"
        second = "pytorch_model-00002-of-00003.bin"
        third = "pytorch_model-00003-of-00003.bin"

        # Split is first layer, second layer then last 2.
        expected_index = {
            "metadata": {"total_size": 340000},
            "weight_map": {"0.weight": first, "1.weight": second, "2.weight": third, "3.weight": third},
        }
        self.assertDictEqual(index, expected_index)

        expected_shards = {
            first: pick("0.weight"),
            second: pick("1.weight"),
            third: pick("2.weight", "3.weight"),
        }
        self.assertDictEqual(shards, expected_shards)
|
||||
|
||||
def test_checkpoint_sharding_local(self):
    """Save a tiny BERT model with several `max_shard_size` values and validate what lands on disk.

    For every size this checks that an index file exists and the regular weight file does not, that a
    shard clearly exceeding the limit holds a single weight, that the index and the `.bin` files in
    the folder agree, and that the model reloaded from the folder has the same parameters.
    """
    model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
        for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
            model.save_pretrained(tmp_dir, max_shard_size=max_size)

            # The limit in bytes only depends on `max_size`, so compute it once per save rather than
            # once per shard: the "kiB" suffix is binary (2**10), the "kB" suffix is decimal (10**3).
            if max_size.endswith("kiB"):
                max_size_int = int(max_size[:-3]) * 2**10
            else:
                max_size_int = int(max_size[:-2]) * 10**3

            # Get each shard file and its size
            shard_to_size = {}
            for shard in os.listdir(tmp_dir):
                if shard.endswith(".bin"):
                    shard_file = os.path.join(tmp_dir, shard)
                    shard_to_size[shard_file] = os.path.getsize(shard_file)

            index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)
            # Check there is an index but no regular weight file
            self.assertTrue(os.path.isfile(index_file))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))

            # Check a file is bigger than max_size only when it has a single weight
            for shard_file, size in shard_to_size.items():
                # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
                # the size asked for (since we count parameters)
                if size >= max_size_int + 50000:
                    state_dict = torch.load(shard_file)
                    self.assertEqual(len(state_dict), 1)

            # Check the index and the shard files found match
            with open(index_file, "r", encoding="utf-8") as f:
                index = json.load(f)

            all_shards = set(index["weight_map"].values())
            shards_found = {name for name in os.listdir(tmp_dir) if name.endswith(".bin")}
            self.assertSetEqual(all_shards, shards_found)

            # Finally, check the model can be reloaded
            new_model = BertModel.from_pretrained(tmp_dir)
            for p1, p2 in zip(model.parameters(), new_model.parameters()):
                self.assertTrue(torch.allclose(p1, p2))
|
||||
|
||||
def test_checkpoint_sharding_from_hub(self):
    """Load the sharded tiny BERT checkpoint by name and compare its parameters with the regular one."""
    sharded_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
    # the model above is the same as the model below, just a sharded version.
    reference_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

    param_pairs = zip(sharded_model.parameters(), reference_model.parameters())
    for sharded_param, reference_param in param_pairs:
        is_close = torch.allclose(sharded_param, reference_param)
        self.assertTrue(is_close)
|
||||
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
|
||||
Reference in New Issue
Block a user