Flax sharded (#17760)

This commit is contained in:
Arthur
2022-06-22 07:04:35 +02:00
committed by GitHub
parent 3b00b623b7
commit 16c6eb7ca1
4 changed files with 305 additions and 22 deletions

View File

@@ -14,6 +14,7 @@
import copy
import inspect
import json
import random
import tempfile
import unittest
@@ -45,6 +46,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import (
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -58,6 +60,7 @@ if is_flax_available():
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@@ -1043,6 +1046,59 @@ class FlaxModelTesterMixin:
# Check if all required parmas are loaded
_assert_all_params_initialised(model, params)
def test_checkpoint_sharding_from_hub(self):
model = FlaxBertModel.from_pretrained("ArthurZ/flax-tiny-random-bert-sharded")
# the model above is the same as the model below, just a sharded version.
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
assert np.allclose(np.array(p1), np.array(p2))
def test_checkpoint_sharding_local(self):
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size)
# Get each shard file and its size
shard_to_size = {}
for shard in os.listdir(tmp_dir):
if shard.endswith(".msgpack"):
shard_file = os.path.join(tmp_dir, shard)
shard_to_size[shard_file] = os.path.getsize(shard_file)
index_file = os.path.join(tmp_dir, FLAX_WEIGHTS_INDEX_NAME)
# Check there is an index but no regular weight file
self.assertTrue(os.path.isfile(index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
# Check a file is bigger than max_size only when it has a single weight
for shard_file, size in shard_to_size.items():
if max_size.endswith("kiB"):
max_size_int = int(max_size[:-3]) * 2**10
else:
max_size_int = int(max_size[:-2]) * 10**3
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
with open(shard_file, "rb") as state_f:
state_file = from_bytes(FlaxBertModel, state_f.read())
self.assertEqual(len(state_file), 1)
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".msgpack"))
self.assertSetEqual(all_shards, shards_found)
# Finally, check the model can be reloaded
new_model = FlaxBertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
@require_flax
@is_staging_test