Use huggingface_hub helper function to split state dict (#31091)
* shard saving from hf hub * index = None * fix tests * indent
This commit is contained in:
@@ -669,7 +669,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
|
||||
for max_size in ["50kB", "100kB", "200kB"]:
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)
|
||||
|
||||
# Get each shard file and its size
|
||||
@@ -686,10 +686,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
# Check a file is bigger than max_size only when it has a single weight
|
||||
for shard_file, size in shard_to_size.items():
|
||||
if max_size.endswith("kiB"):
|
||||
max_size_int = int(max_size[:-3]) * 2**10
|
||||
else:
|
||||
max_size_int = int(max_size[:-2]) * 10**3
|
||||
max_size_int = int(max_size[:-2]) * 10**3
|
||||
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||
# the size asked for (since we count parameters)
|
||||
if size >= max_size_int + 50000:
|
||||
|
||||
Reference in New Issue
Block a user