Flax sharded (#17760)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -45,6 +46,7 @@ if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.serialization import from_bytes
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from transformers import (
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
@@ -58,6 +60,7 @@ if is_flax_available():
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
@@ -1043,6 +1046,59 @@ class FlaxModelTesterMixin:
|
||||
# Check if all required parmas are loaded
|
||||
_assert_all_params_initialised(model, params)
|
||||
|
||||
def test_checkpoint_sharding_from_hub(self):
|
||||
model = FlaxBertModel.from_pretrained("ArthurZ/flax-tiny-random-bert-sharded")
|
||||
# the model above is the same as the model below, just a sharded version.
|
||||
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
|
||||
assert np.allclose(np.array(p1), np.array(p2))
|
||||
|
||||
def test_checkpoint_sharding_local(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||
|
||||
# Get each shard file and its size
|
||||
shard_to_size = {}
|
||||
for shard in os.listdir(tmp_dir):
|
||||
if shard.endswith(".msgpack"):
|
||||
shard_file = os.path.join(tmp_dir, shard)
|
||||
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||
|
||||
index_file = os.path.join(tmp_dir, FLAX_WEIGHTS_INDEX_NAME)
|
||||
# Check there is an index but no regular weight file
|
||||
self.assertTrue(os.path.isfile(index_file))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
|
||||
|
||||
# Check a file is bigger than max_size only when it has a single weight
|
||||
for shard_file, size in shard_to_size.items():
|
||||
if max_size.endswith("kiB"):
|
||||
max_size_int = int(max_size[:-3]) * 2**10
|
||||
else:
|
||||
max_size_int = int(max_size[:-2]) * 10**3
|
||||
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||
# the size asked for (since we count parameters)
|
||||
if size >= max_size_int + 50000:
|
||||
with open(shard_file, "rb") as state_f:
|
||||
state_file = from_bytes(FlaxBertModel, state_f.read())
|
||||
self.assertEqual(len(state_file), 1)
|
||||
|
||||
# Check the index and the shard files found match
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".msgpack"))
|
||||
self.assertSetEqual(all_shards, shards_found)
|
||||
|
||||
# Finally, check the model can be reloaded
|
||||
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
|
||||
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user