Flax sharded (#17760)
This commit is contained in:
@@ -13,11 +13,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pickle import UnpicklingError
|
from pickle import UnpicklingError
|
||||||
from typing import Any, Dict, Set, Tuple, Union
|
from typing import Any, Dict, Set, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -33,6 +39,7 @@ from .dynamic_module_utils import custom_object_save
|
|||||||
from .generation_flax_utils import FlaxGenerationMixin
|
from .generation_flax_utils import FlaxGenerationMixin
|
||||||
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
FLAX_WEIGHTS_INDEX_NAME,
|
||||||
FLAX_WEIGHTS_NAME,
|
FLAX_WEIGHTS_NAME,
|
||||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
@@ -51,6 +58,7 @@ from .utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -70,6 +78,88 @@ ACT2FN = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def dtype_byte_size(dtype):
|
||||||
|
"""
|
||||||
|
Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:
|
||||||
|
```py
|
||||||
|
>>> dtype_byte_size(np.float32)
|
||||||
|
4
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if dtype == np.bool:
|
||||||
|
return 1 / 8
|
||||||
|
bit_search = re.search("[^\d](\d+)$", dtype.name)
|
||||||
|
if bit_search is None:
|
||||||
|
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||||
|
bit_size = int(bit_search.groups()[0])
|
||||||
|
return bit_size // 8
|
||||||
|
|
||||||
|
|
||||||
|
def flax_shard_checkpoint(params, max_shard_size="10GB"):
|
||||||
|
"""
|
||||||
|
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
|
||||||
|
given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
|
||||||
|
there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
|
||||||
|
example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
|
||||||
|
[6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
If one of the model's weight is bigger that `max_shard_size`, it will end up in its own sub-checkpoint which will
|
||||||
|
have a size greater than `max_shard_size`.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
|
||||||
|
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||||
|
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
|
||||||
|
(like `"5MB"`).
|
||||||
|
"""
|
||||||
|
max_shard_size = convert_file_size_to_int(max_shard_size)
|
||||||
|
|
||||||
|
sharded_state_dicts = []
|
||||||
|
current_block = {}
|
||||||
|
current_block_size = 0
|
||||||
|
total_size = 0
|
||||||
|
|
||||||
|
# flatten the weights to chunk
|
||||||
|
weights = flatten_dict(params, sep="/")
|
||||||
|
for item in weights:
|
||||||
|
weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
|
||||||
|
|
||||||
|
# If this weight is going to tip up over the maximal size, we split.
|
||||||
|
if current_block_size + weight_size > max_shard_size:
|
||||||
|
sharded_state_dicts.append(current_block)
|
||||||
|
current_block = {}
|
||||||
|
current_block_size = 0
|
||||||
|
|
||||||
|
current_block[item] = weights[item]
|
||||||
|
current_block_size += weight_size
|
||||||
|
total_size += weight_size
|
||||||
|
|
||||||
|
# Add the last block
|
||||||
|
sharded_state_dicts.append(current_block)
|
||||||
|
|
||||||
|
# If we only have one shard, we return it
|
||||||
|
if len(sharded_state_dicts) == 1:
|
||||||
|
return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None
|
||||||
|
|
||||||
|
# Otherwise, let's build the index
|
||||||
|
weight_map = {}
|
||||||
|
shards = {}
|
||||||
|
for idx, shard in enumerate(sharded_state_dicts):
|
||||||
|
shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
|
||||||
|
shards[shard_file] = shard
|
||||||
|
for weight_name in shard.keys():
|
||||||
|
weight_map[weight_name] = shard_file
|
||||||
|
|
||||||
|
# Add the metadata
|
||||||
|
metadata = {"total_size": total_size}
|
||||||
|
index = {"metadata": metadata, "weight_map": weight_map}
|
||||||
|
return shards, index
|
||||||
|
|
||||||
|
|
||||||
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
r"""
|
r"""
|
||||||
Base class for all models.
|
Base class for all models.
|
||||||
@@ -333,6 +423,53 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
```"""
|
```"""
|
||||||
return self._cast_floating_to(params, jnp.float16, mask)
|
return self._cast_floating_to(params, jnp.float16, mask)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_flax_sharded_weights(cls, shard_files):
|
||||||
|
"""
|
||||||
|
This is the same as [`flax.serialization.from_bytes`]
|
||||||
|
(https:lax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.
|
||||||
|
|
||||||
|
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
|
||||||
|
loaded in the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shard_files (`List[str]`:
|
||||||
|
The list of shard files to load.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model':
|
||||||
|
{'params': {'...'}}}`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Load the index
|
||||||
|
state_sharded_dict = dict()
|
||||||
|
|
||||||
|
for shard_file in shard_files:
|
||||||
|
# load using msgpack utils
|
||||||
|
try:
|
||||||
|
with open(shard_file, "rb") as state_f:
|
||||||
|
state = from_bytes(cls, state_f.read())
|
||||||
|
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
|
||||||
|
with open(shard_file) as f:
|
||||||
|
if f.read().startswith("version"):
|
||||||
|
raise OSError(
|
||||||
|
"You seem to have cloned a repository without having git-lfs installed. Please"
|
||||||
|
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
||||||
|
" folder you cloned."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError from e
|
||||||
|
except (UnicodeDecodeError, ValueError):
|
||||||
|
raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")
|
||||||
|
|
||||||
|
state = flatten_dict(state, sep="/")
|
||||||
|
state_sharded_dict.update(state)
|
||||||
|
del state
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# the state dict is unflattened to the match the format of model.params
|
||||||
|
return unflatten_dict(state_sharded_dict, sep="/")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
@@ -489,6 +626,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
# Add the dtype to model_kwargs
|
# Add the dtype to model_kwargs
|
||||||
model_kwargs["dtype"] = dtype
|
model_kwargs["dtype"] = dtype
|
||||||
|
|
||||||
|
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||||
|
# index of the files.
|
||||||
|
is_sharded = False
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
if pretrained_model_name_or_path is not None:
|
if pretrained_model_name_or_path is not None:
|
||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
@@ -498,6 +639,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||||
# Load from a Flax checkpoint
|
# Load from a Flax checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||||
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)):
|
||||||
|
# Load from a sharded Flax checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
|
||||||
|
is_sharded = True
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
@@ -521,6 +666,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(
|
resolved_archive_file = cached_path(
|
||||||
archive_file,
|
archive_file,
|
||||||
@@ -548,12 +694,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
)
|
)
|
||||||
except EntryNotFoundError:
|
except EntryNotFoundError:
|
||||||
if filename == FLAX_WEIGHTS_NAME:
|
if filename == FLAX_WEIGHTS_NAME:
|
||||||
|
try:
|
||||||
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
|
archive_file = hf_bucket_url(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
filename=FLAX_WEIGHTS_INDEX_NAME,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
resolved_archive_file = cached_path(
|
||||||
|
archive_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
is_sharded = True
|
||||||
|
except EntryNotFoundError:
|
||||||
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
|
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
|
||||||
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to load"
|
f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
|
||||||
" this model from those weights."
|
" load this model from those weights."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
@@ -592,14 +757,34 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
else:
|
else:
|
||||||
resolved_archive_file = None
|
resolved_archive_file = None
|
||||||
|
|
||||||
|
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
||||||
|
if is_sharded:
|
||||||
|
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
||||||
|
resolved_archive_file, _ = get_checkpoint_shard_files(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
resolved_archive_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
# init random models
|
# init random models
|
||||||
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
|
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
|
||||||
|
|
||||||
if from_pt:
|
if from_pt:
|
||||||
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
|
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
|
||||||
else:
|
else:
|
||||||
with open(resolved_archive_file, "rb") as state_f:
|
|
||||||
|
if is_sharded:
|
||||||
|
state = cls.load_flax_sharded_weights(resolved_archive_file)
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
|
with open(resolved_archive_file, "rb") as state_f:
|
||||||
state = from_bytes(cls, state_f.read())
|
state = from_bytes(cls, state_f.read())
|
||||||
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
|
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
|
||||||
try:
|
try:
|
||||||
@@ -742,7 +927,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
else:
|
else:
|
||||||
return model, unflatten_dict(state)
|
return model, unflatten_dict(state)
|
||||||
|
|
||||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
|
def save_pretrained(
|
||||||
|
self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, max_shard_size="10GB", **kwargs
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||||
`[`~FlaxPreTrainedModel.from_pretrained`]` class method
|
`[`~FlaxPreTrainedModel.from_pretrained`]` class method
|
||||||
@@ -761,6 +948,17 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
|
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||||
|
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
||||||
|
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
|
||||||
|
which will be bigger than `max_shard_size`.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
kwargs:
|
kwargs:
|
||||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
"""
|
"""
|
||||||
@@ -788,11 +986,42 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
|
|
||||||
# save model
|
# save model
|
||||||
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
|
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
|
||||||
|
|
||||||
|
shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
|
||||||
|
# Clean the folder from a previous save
|
||||||
|
for filename in os.listdir(save_directory):
|
||||||
|
full_filename = os.path.join(save_directory, filename)
|
||||||
|
if (
|
||||||
|
filename.startswith(FLAX_WEIGHTS_NAME[:-4])
|
||||||
|
and os.path.isfile(full_filename)
|
||||||
|
and filename not in shards.keys()
|
||||||
|
):
|
||||||
|
os.remove(full_filename)
|
||||||
|
|
||||||
|
if index is None:
|
||||||
with open(output_model_file, "wb") as f:
|
with open(output_model_file, "wb") as f:
|
||||||
params = params if params is not None else self.params
|
params = params if params is not None else self.params
|
||||||
model_bytes = to_bytes(params)
|
model_bytes = to_bytes(params)
|
||||||
f.write(model_bytes)
|
f.write(model_bytes)
|
||||||
|
|
||||||
|
else:
|
||||||
|
save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
|
||||||
|
# Save the index as well
|
||||||
|
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||||
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||||
|
f.write(content)
|
||||||
|
logger.info(
|
||||||
|
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||||
|
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||||
|
f"index located at {save_index_file}."
|
||||||
|
)
|
||||||
|
for shard_file, shard in shards.items():
|
||||||
|
# the shard item are unflattened, to save them we need to flatten them again
|
||||||
|
with open(os.path.join(save_directory, shard_file), mode="wb") as f:
|
||||||
|
params = unflatten_dict(shard, sep="/")
|
||||||
|
shard_bytes = to_bytes(params)
|
||||||
|
f.write(shard_bytes)
|
||||||
|
|
||||||
logger.info(f"Model weights saved in {output_model_file}")
|
logger.info(f"Model weights saved in {output_model_file}")
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ TF2_WEIGHTS_NAME = "tf_model.h5"
|
|||||||
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
|
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
|
||||||
TF_WEIGHTS_NAME = "model.ckpt"
|
TF_WEIGHTS_NAME = "model.ckpt"
|
||||||
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
||||||
|
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
|
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
|
||||||
MODEL_CARD_NAME = "modelcard.json"
|
MODEL_CARD_NAME = "modelcard.json"
|
||||||
|
|||||||
@@ -937,7 +937,7 @@ class PushToHubMixin:
|
|||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
# Save the files in the cloned repo
|
# Save the files in the cloned repo
|
||||||
|
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
|
||||||
if hasattr(self, "history") and hasattr(self, "create_model_card"):
|
if hasattr(self, "history") and hasattr(self, "create_model_card"):
|
||||||
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
|
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
|
||||||
# This is a Keras model and we might be able to fish out its History and make a model card out of it
|
# This is a Keras model and we might be able to fish out its History and make a model card out of it
|
||||||
@@ -947,9 +947,7 @@ class PushToHubMixin:
|
|||||||
}
|
}
|
||||||
base_model_card_args.update(model_card_kwargs)
|
base_model_card_args.update(model_card_kwargs)
|
||||||
self.create_model_card(**base_model_card_args)
|
self.create_model_card(**base_model_card_args)
|
||||||
else:
|
|
||||||
# FLAX does not support sharding yet, will come in next PR
|
|
||||||
self.save_pretrained(repo_path_or_name)
|
|
||||||
# Commit and push!
|
# Commit and push!
|
||||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||||
|
|
||||||
@@ -1090,7 +1088,6 @@ def convert_file_size_to_int(size: Union[int, str]):
|
|||||||
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
|
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```py
|
```py
|
||||||
>>> convert_file_size_to_int("1MiB")
|
>>> convert_file_size_to_int("1MiB")
|
||||||
1048576
|
1048576
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
@@ -45,6 +46,7 @@ if is_flax_available():
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
|
from flax.serialization import from_bytes
|
||||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from transformers import (
|
from transformers import (
|
||||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
@@ -58,6 +60,7 @@ if is_flax_available():
|
|||||||
convert_pytorch_state_dict_to_flax,
|
convert_pytorch_state_dict_to_flax,
|
||||||
load_flax_weights_in_pytorch_model,
|
load_flax_weights_in_pytorch_model,
|
||||||
)
|
)
|
||||||
|
from transformers.modeling_flax_utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME
|
||||||
|
|
||||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||||
|
|
||||||
@@ -1043,6 +1046,59 @@ class FlaxModelTesterMixin:
|
|||||||
# Check if all required parmas are loaded
|
# Check if all required parmas are loaded
|
||||||
_assert_all_params_initialised(model, params)
|
_assert_all_params_initialised(model, params)
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_from_hub(self):
|
||||||
|
model = FlaxBertModel.from_pretrained("ArthurZ/flax-tiny-random-bert-sharded")
|
||||||
|
# the model above is the same as the model below, just a sharded version.
|
||||||
|
ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
|
||||||
|
assert np.allclose(np.array(p1), np.array(p2))
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_local(self):
|
||||||
|
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||||
|
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||||
|
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||||
|
|
||||||
|
# Get each shard file and its size
|
||||||
|
shard_to_size = {}
|
||||||
|
for shard in os.listdir(tmp_dir):
|
||||||
|
if shard.endswith(".msgpack"):
|
||||||
|
shard_file = os.path.join(tmp_dir, shard)
|
||||||
|
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||||
|
|
||||||
|
index_file = os.path.join(tmp_dir, FLAX_WEIGHTS_INDEX_NAME)
|
||||||
|
# Check there is an index but no regular weight file
|
||||||
|
self.assertTrue(os.path.isfile(index_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
# Check a file is bigger than max_size only when it has a single weight
|
||||||
|
for shard_file, size in shard_to_size.items():
|
||||||
|
if max_size.endswith("kiB"):
|
||||||
|
max_size_int = int(max_size[:-3]) * 2**10
|
||||||
|
else:
|
||||||
|
max_size_int = int(max_size[:-2]) * 10**3
|
||||||
|
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||||
|
# the size asked for (since we count parameters)
|
||||||
|
if size >= max_size_int + 50000:
|
||||||
|
with open(shard_file, "rb") as state_f:
|
||||||
|
state_file = from_bytes(FlaxBertModel, state_f.read())
|
||||||
|
self.assertEqual(len(state_file), 1)
|
||||||
|
|
||||||
|
# Check the index and the shard files found match
|
||||||
|
with open(index_file, "r", encoding="utf-8") as f:
|
||||||
|
index = json.loads(f.read())
|
||||||
|
|
||||||
|
all_shards = set(index["weight_map"].values())
|
||||||
|
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".msgpack"))
|
||||||
|
self.assertSetEqual(all_shards, shards_found)
|
||||||
|
|
||||||
|
# Finally, check the model can be reloaded
|
||||||
|
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||||
|
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
|
||||||
|
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user