🧼 remove v4.44 deprecations (#34245)
* remove v4.44 deprecations
* PR comments
* deprecations scheduled for v4.50
* hub version update
* make fixup
---------
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -105,7 +105,6 @@ if is_torch_available():
|
||||
_find_disjoint,
|
||||
_find_identical,
|
||||
dtype_byte_size,
|
||||
shard_checkpoint,
|
||||
)
|
||||
from transformers.pytorch_utils import isin_mps_friendly
|
||||
|
||||
@@ -668,71 +667,6 @@ class ModelUtilsTest(TestCasePlus):
|
||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
def test_shard_checkpoint(self):
|
||||
# This is the model we will use, total size 340,000 bytes.
|
||||
model = torch.nn.Sequential(
|
||||
torch.nn.Linear(100, 200, bias=False), # size 80,000
|
||||
torch.nn.Linear(200, 200, bias=False), # size 160,000
|
||||
torch.nn.Linear(200, 100, bias=False), # size 80,000
|
||||
torch.nn.Linear(100, 50, bias=False), # size 20,000
|
||||
)
|
||||
state_dict = model.state_dict()
|
||||
|
||||
with self.subTest("No shard when max size is bigger than model size"):
|
||||
shards, index = shard_checkpoint(state_dict)
|
||||
self.assertIsNone(index)
|
||||
self.assertDictEqual(shards, {WEIGHTS_NAME: state_dict})
|
||||
|
||||
with self.subTest("Test sharding, no weights bigger than max size"):
|
||||
shards, index = shard_checkpoint(state_dict, max_shard_size="300kB")
|
||||
# Split is first two layers then last two.
|
||||
self.assertDictEqual(
|
||||
index,
|
||||
{
|
||||
"metadata": {"total_size": 340000},
|
||||
"weight_map": {
|
||||
"0.weight": "pytorch_model-00001-of-00002.bin",
|
||||
"1.weight": "pytorch_model-00001-of-00002.bin",
|
||||
"2.weight": "pytorch_model-00002-of-00002.bin",
|
||||
"3.weight": "pytorch_model-00002-of-00002.bin",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
shard1 = {"0.weight": state_dict["0.weight"], "1.weight": state_dict["1.weight"]}
|
||||
shard2 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
|
||||
self.assertDictEqual(
|
||||
shards, {"pytorch_model-00001-of-00002.bin": shard1, "pytorch_model-00002-of-00002.bin": shard2}
|
||||
)
|
||||
|
||||
with self.subTest("Test sharding with weights bigger than max size"):
|
||||
shards, index = shard_checkpoint(state_dict, max_shard_size="100kB")
|
||||
# Split is first layer, second layer then last 2.
|
||||
self.assertDictEqual(
|
||||
index,
|
||||
{
|
||||
"metadata": {"total_size": 340000},
|
||||
"weight_map": {
|
||||
"0.weight": "pytorch_model-00001-of-00003.bin",
|
||||
"1.weight": "pytorch_model-00002-of-00003.bin",
|
||||
"2.weight": "pytorch_model-00003-of-00003.bin",
|
||||
"3.weight": "pytorch_model-00003-of-00003.bin",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
shard1 = {"0.weight": state_dict["0.weight"]}
|
||||
shard2 = {"1.weight": state_dict["1.weight"]}
|
||||
shard3 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
|
||||
self.assertDictEqual(
|
||||
shards,
|
||||
{
|
||||
"pytorch_model-00001-of-00003.bin": shard1,
|
||||
"pytorch_model-00002-of-00003.bin": shard2,
|
||||
"pytorch_model-00003-of-00003.bin": shard3,
|
||||
},
|
||||
)
|
||||
|
||||
def test_checkpoint_sharding_local_bin(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user