Checkpoint sharding (#16343)
* Sharded checkpoint support * Handle distant sharded checkpoints * Add tests * TODO is done * Apply suggestions from code review Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Fix docstring * Add example and format * Address review comments * More review comments * End of merge * Revert unintentional change * VsCode what did you do? * Style * Changes * Address final comments * Quality * Moar tests * Move import beneath is_pt_available Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -52,6 +52,7 @@ from .utils import (
|
|||||||
USE_JAX,
|
USE_JAX,
|
||||||
USE_TF,
|
USE_TF,
|
||||||
USE_TORCH,
|
USE_TORCH,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
ContextManagers,
|
ContextManagers,
|
||||||
DummyObject,
|
DummyObject,
|
||||||
|
|||||||
@@ -15,11 +15,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -38,6 +42,7 @@ from .utils import (
|
|||||||
FLAX_WEIGHTS_NAME,
|
FLAX_WEIGHTS_NAME,
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
TF_WEIGHTS_NAME,
|
TF_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
EntryNotFoundError,
|
EntryNotFoundError,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
@@ -45,7 +50,6 @@ from .utils import (
|
|||||||
RepositoryNotFoundError,
|
RepositoryNotFoundError,
|
||||||
RevisionNotFoundError,
|
RevisionNotFoundError,
|
||||||
cached_path,
|
cached_path,
|
||||||
copy_func,
|
|
||||||
has_file,
|
has_file,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
@@ -148,6 +152,272 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
|
|||||||
return first_tuple[1].dtype
|
return first_tuple[1].dtype
|
||||||
|
|
||||||
|
|
||||||
|
def convert_file_size_to_int(size: Union[int, str]):
    """
    Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).

    Binary units (`KiB`, `MiB`, `GiB`) are powers of 1024 while decimal units (`KB`, `MB`, `GB`) are powers of 1000.
    Units are matched case-insensitively.

    Args:
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

    Returns:
        `int`: The size in bytes.

    Raises:
        ValueError: If `size` is not an integer followed by one of the supported units.

    Example:

    ```py
    >>> convert_file_size_to_int("1MiB")
    1048576
    >>> convert_file_size_to_int("1MB")
    1000000
    ```
    """
    if isinstance(size, int):
        return size

    invalid_format = "`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'."
    if not isinstance(size, str):
        raise ValueError(invalid_format)

    # (suffix, multiplier) pairs; suffixes are compared against the upper-cased input.
    units = (
        ("GIB", 2**30),
        ("MIB", 2**20),
        ("KIB", 2**10),
        ("GB", 10**9),
        ("MB", 10**6),
        ("KB", 10**3),
    )
    upper_size = size.upper()
    for suffix, multiplier in units:
        if upper_size.endswith(suffix):
            try:
                return int(size[: -len(suffix)]) * multiplier
            except ValueError as err:
                # The unit was recognized but what precedes it is not an integer: report the expected format
                # instead of the bare `int()` parsing error.
                raise ValueError(invalid_format) from err
    raise ValueError(invalid_format)
|
||||||
|
|
||||||
|
|
||||||
|
def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    The size is derived from the number of bits that ends the string representation of `dtype` (e.g. the `32` of
    `torch.float32`). `torch.bool` is special-cased and counted as one bit, so the returned value is the fraction
    `1 / 8` in that case.

    Args:
        dtype: The dtype to get the size of.

    Returns:
        `int` or `float`: The number of bytes per element (`1 / 8` for `torch.bool`).

    Raises:
        ValueError: If the string representation of `dtype` does not end with a number of bits.

    Example:

    ```py
    >>> dtype_byte_size(torch.float32)
    4
    ```
    """
    if dtype == torch.bool:
        return 1 / 8
    # Raw string: `\d` inside a regular string literal is an invalid escape sequence.
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.group(1))
    return bit_size // 8
|
||||||
|
|
||||||
|
|
||||||
|
def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).

    Returns:
        A tuple `(shards, index)`. `shards` maps each checkpoint file name to the state dictionary to save under that
        name. `index` is `None` when everything fits in a single shard (saved under `WEIGHTS_NAME`); otherwise it is a
        dictionary with a `"metadata"` entry (holding `"total_size"`) and a `"weight_map"` entry mapping each weight
        name to the shard file containing it.
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    for key, weight in state_dict.items():
        weight_size = weight.numel() * dtype_byte_size(weight.dtype)

        # If this weight is going to tip up over the maximal size, we split. We only do so when the current block
        # already holds a weight: otherwise a weight bigger than `max_shard_size` would produce an empty shard.
        if current_block and current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0

        current_block[key] = weight
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    num_shards = len(sharded_state_dicts)
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{num_shards:05d}.bin")
        shards[shard_file] = shard
        for key in shard:
            weight_map[key] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    local_files_only=False,
    use_auth_token=None,
    user_agent=None,
    revision=None,
    mirror=None,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
      Hub
    - returns the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).

    Returns:
        A tuple `(shard_files, sharded_metadata)`. `shard_files` is the list of paths to the shards, ordered by shard
        file name. `sharded_metadata` is the `"metadata"` entry of the index, extended with an
        `"all_checkpoint_keys"` entry listing every weight name found in the index.

    Raises:
        EnvironmentError: If a shard listed in the index cannot be found or downloaded.
    """
    # The index is saved as UTF-8 JSON, so read it back with the same encoding instead of the platform default.
    with open(index_filename, "r", encoding="utf-8") as f:
        index = json.load(f)

    # Several weights live in the same shard: deduplicate the file names and keep a deterministic order.
    shard_filenames = sorted(set(index["weight_map"].values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())

    # First, let's deal with local folder.
    if os.path.isdir(pretrained_model_name_or_path):
        shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
        return shard_filenames, sharded_metadata

    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
    cached_filenames = []
    for shard_filename in shard_filenames:
        shard_url = hf_bucket_url(
            pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
        )

        try:
            # Load from URL
            cached_filename = cached_path(
                shard_url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
        # we don't have to catch them here.
        except EntryNotFoundError as err:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
                "required according to the checkpoint index."
            ) from err
        except HTTPError as err:
            raise EnvironmentError(
                f"We couldn't connect to 'https://huggingface.co/' to load {shard_filename}. You should try again "
                "after checking your internet connection."
            ) from err

        cached_filenames.append(cached_filename)

    return cached_filenames, sharded_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
    """
    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.

    Args:
        checkpoint_file (`str` or `os.PathLike`): Path to the checkpoint file to read.

    Returns:
        The object returned by `torch.load`, loaded with `map_location="cpu"`.

    Raises:
        OSError: If the file cannot be loaded by `torch.load`. A dedicated message is used when the file starts with
            `"version"`, which is treated as the sign of a repository cloned without git-lfs.
    """
    try:
        return torch.load(checkpoint_file, map_location="cpu")
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                # Only peek at the beginning of the file: reading a whole checkpoint just to inspect its first
                # characters would needlessly pull it all into memory.
                if f.read(7) == "version":
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    # Caught right below and converted to the generic OSError.
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            # Chain the original loading error so the real cause stays visible in the traceback.
            raise OSError(
                f"Unable to load weights from pytorch checkpoint file '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            ) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    """
    Loads `state_dict` into `model_to_load` and, recursively, into all of its submodules.

    Keys of `state_dict` using the old naming are first renamed in place (`gamma` becomes `weight`, `beta` becomes
    `bias`). The loading itself is then done on a shallow copy of the dictionary.

    Args:
        model_to_load: The module to load the weights into.
        state_dict: The state dictionary holding the weights. Its old-format keys are renamed in place.
        start_prefix: The prefix used for the keys of `model_to_load` itself; each submodule is visited with this
            prefix extended by its name and a dot.

    Returns:
        The list of error messages filled in by the `_load_from_state_dict` calls.
    """
    # Convert old format to new format if needed from a PyTorch state_dict. The renames are collected before being
    # applied since the dictionary is being iterated over.
    renames = []
    for name in state_dict.keys():
        renamed = None
        if "gamma" in name:
            renamed = name.replace("gamma", "weight")
        if "beta" in name:
            renamed = name.replace("beta", "bias")
        if renamed:
            renames.append((name, renamed))
    for name, renamed in renames:
        state_dict[renamed] = state_dict.pop(name)

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, prefix=""):
        if metadata is None:
            local_metadata = {}
        else:
            # The metadata is looked up under the prefix without its trailing dot.
            local_metadata = metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)

        if not is_deepspeed_zero3_enabled():
            module._load_from_state_dict(*args)
        else:
            import deepspeed

            # because zero3 puts placeholders in model params, this context
            # manager gathers (unpartitions) the params of the current layer, then loads from
            # the state dict and then re-partitions them again
            with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
                if torch.distributed.get_rank() == 0:
                    module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is None:
                continue
            load(child, prefix + name + ".")

    load(model_to_load, prefix=start_prefix)

    return error_msgs
|
||||||
|
|
||||||
|
|
||||||
class ModuleUtilsMixin:
|
class ModuleUtilsMixin:
|
||||||
"""
|
"""
|
||||||
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
||||||
@@ -1004,6 +1274,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
state_dict: Optional[dict] = None,
|
state_dict: Optional[dict] = None,
|
||||||
save_function: Callable = torch.save,
|
save_function: Callable = torch.save,
|
||||||
push_to_hub: bool = False,
|
push_to_hub: bool = False,
|
||||||
|
max_shard_size: Union[int, str] = "10GB",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1035,6 +1306,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
|
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||||
|
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
|
||||||
|
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
|
||||||
|
which will be bigger than `max_shard_size`.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
kwargs:
|
kwargs:
|
||||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||||
"""
|
"""
|
||||||
@@ -1078,11 +1360,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if ignore_key in state_dict.keys():
|
if ignore_key in state_dict.keys():
|
||||||
del state_dict[ignore_key]
|
del state_dict[ignore_key]
|
||||||
|
|
||||||
# If we save using the predefined names, we can load using `from_pretrained`
|
# Shard the model if it is too big.
|
||||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
|
||||||
save_function(state_dict, output_model_file)
|
|
||||||
|
|
||||||
logger.info(f"Model weights saved in {output_model_file}")
|
# Clean the folder from a previous save
|
||||||
|
for filename in os.listdir(save_directory):
|
||||||
|
full_filename = os.path.join(save_directory, filename)
|
||||||
|
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
|
||||||
|
os.remove(full_filename)
|
||||||
|
|
||||||
|
# Save the model
|
||||||
|
for shard_file, shard in shards.items():
|
||||||
|
save_function(shard, os.path.join(save_directory, shard_file))
|
||||||
|
|
||||||
|
if index is None:
|
||||||
|
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
|
||||||
|
else:
|
||||||
|
save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
|
||||||
|
# Save the index as well
|
||||||
|
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||||
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||||
|
f.write(content)
|
||||||
|
logger.info(
|
||||||
|
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||||
|
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||||
|
f"index located at {save_index_file}."
|
||||||
|
)
|
||||||
|
|
||||||
if push_to_hub:
|
if push_to_hub:
|
||||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||||
@@ -1293,6 +1596,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
model_kwargs = kwargs
|
model_kwargs = kwargs
|
||||||
|
|
||||||
|
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||||
|
# index of the files.
|
||||||
|
is_sharded = False
|
||||||
|
sharded_metadata = None
|
||||||
# Load model
|
# Load model
|
||||||
if pretrained_model_name_or_path is not None:
|
if pretrained_model_name_or_path is not None:
|
||||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||||
@@ -1309,6 +1616,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
|
||||||
|
# Load from a sharded PyTorch checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||||
|
is_sharded = True
|
||||||
# At this stage we don't have a weight file so we will raise an error.
|
# At this stage we don't have a weight file so we will raise an error.
|
||||||
elif os.path.isfile(
|
elif os.path.isfile(
|
||||||
os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
||||||
@@ -1382,29 +1693,51 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
except EntryNotFoundError:
|
except EntryNotFoundError:
|
||||||
if filename == WEIGHTS_NAME:
|
if filename == WEIGHTS_NAME:
|
||||||
has_file_kwargs = {
|
try:
|
||||||
"revision": revision,
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||||
"mirror": mirror,
|
archive_file = hf_bucket_url(
|
||||||
"proxies": proxies,
|
pretrained_model_name_or_path,
|
||||||
"use_auth_token": use_auth_token,
|
filename=WEIGHTS_INDEX_NAME,
|
||||||
}
|
revision=revision,
|
||||||
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
|
mirror=mirror,
|
||||||
raise EnvironmentError(
|
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
|
|
||||||
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
|
|
||||||
"weights."
|
|
||||||
)
|
)
|
||||||
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
|
resolved_archive_file = cached_path(
|
||||||
raise EnvironmentError(
|
archive_file,
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
|
cache_dir=cache_dir,
|
||||||
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
|
force_download=force_download,
|
||||||
"weights."
|
proxies=proxies,
|
||||||
)
|
resume_download=resume_download,
|
||||||
else:
|
local_files_only=local_files_only,
|
||||||
raise EnvironmentError(
|
use_auth_token=use_auth_token,
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, "
|
user_agent=user_agent,
|
||||||
f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
|
||||||
)
|
)
|
||||||
|
is_sharded = True
|
||||||
|
except EntryNotFoundError:
|
||||||
|
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
||||||
|
# message.
|
||||||
|
has_file_kwargs = {
|
||||||
|
"revision": revision,
|
||||||
|
"mirror": mirror,
|
||||||
|
"proxies": proxies,
|
||||||
|
"use_auth_token": use_auth_token,
|
||||||
|
}
|
||||||
|
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
|
||||||
|
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
|
||||||
|
"weights."
|
||||||
|
)
|
||||||
|
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
|
||||||
|
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
|
||||||
|
"weights."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, "
|
||||||
|
f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||||
@@ -1439,29 +1772,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
resolved_archive_file = None
|
resolved_archive_file = None
|
||||||
|
|
||||||
|
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
||||||
|
if is_sharded:
|
||||||
|
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
||||||
|
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
resolved_archive_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
user_agent=user_agent,
|
||||||
|
revision=revision,
|
||||||
|
mirror=mirror,
|
||||||
|
)
|
||||||
|
|
||||||
# load pt weights early so that we know which dtype to init the model under
|
# load pt weights early so that we know which dtype to init the model under
|
||||||
if from_pt:
|
if from_pt:
|
||||||
if state_dict is None:
|
if not is_sharded:
|
||||||
try:
|
# Time to load the checkpoint
|
||||||
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
state_dict = load_state_dict(resolved_archive_file)
|
||||||
except Exception as e:
|
|
||||||
try:
|
|
||||||
with open(resolved_archive_file) as f:
|
|
||||||
if f.read().startswith("version"):
|
|
||||||
raise OSError(
|
|
||||||
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
|
||||||
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
|
||||||
"you cloned."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError from e
|
|
||||||
except (UnicodeDecodeError, ValueError):
|
|
||||||
raise OSError(
|
|
||||||
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
|
||||||
f"at '{resolved_archive_file}'. "
|
|
||||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
|
||||||
)
|
|
||||||
|
|
||||||
# set dtype to instantiate the model under:
|
# set dtype to instantiate the model under:
|
||||||
# 1. If torch_dtype is not None, we use that dtype
|
# 1. If torch_dtype is not None, we use that dtype
|
||||||
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||||
@@ -1471,7 +1803,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if torch_dtype is not None:
|
if torch_dtype is not None:
|
||||||
if isinstance(torch_dtype, str):
|
if isinstance(torch_dtype, str):
|
||||||
if torch_dtype == "auto":
|
if torch_dtype == "auto":
|
||||||
torch_dtype = next(iter(state_dict.values())).dtype
|
if is_sharded and "dtype" in sharded_metadata:
|
||||||
|
torch_dtype = sharded_metadata["dtype"]
|
||||||
|
elif not is_sharded:
|
||||||
|
torch_dtype = next(iter(state_dict.values())).dtype
|
||||||
|
else:
|
||||||
|
one_state_dict = load_state_dict(resolved_archive_file)
|
||||||
|
torch_dtype = next(iter(one_state_dict.values())).dtype
|
||||||
|
del one_state_dict # free CPU memory
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
|
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
|
||||||
@@ -1480,8 +1819,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
if low_cpu_mem_usage:
|
if low_cpu_mem_usage:
|
||||||
# save the keys
|
# save the keys
|
||||||
loaded_state_dict_keys = [k for k in state_dict.keys()]
|
if is_sharded:
|
||||||
del state_dict # free CPU memory - will reload again later
|
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict(resolved_archive_file)
|
||||||
|
loaded_state_dict_keys = [k for k in state_dict.keys()]
|
||||||
|
del state_dict # free CPU memory - will reload again later
|
||||||
|
|
||||||
config.name_or_path = pretrained_model_name_or_path
|
config.name_or_path = pretrained_model_name_or_path
|
||||||
|
|
||||||
@@ -1534,13 +1877,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
elif from_pt:
|
elif from_pt:
|
||||||
|
|
||||||
if low_cpu_mem_usage:
|
if low_cpu_mem_usage:
|
||||||
cls._load_state_dict_into_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
|
cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
|
||||||
else:
|
else:
|
||||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model(
|
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||||
model,
|
model,
|
||||||
state_dict,
|
state_dict,
|
||||||
|
resolved_archive_file,
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
|
sharded_metadata=sharded_metadata,
|
||||||
_fast_init=_fast_init,
|
_fast_init=_fast_init,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1562,31 +1907,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_state_dict_into_model(
|
def _load_pretrained_model(
|
||||||
cls, model, state_dict, pretrained_model_name_or_path, ignore_mismatched_sizes=False, _fast_init=True
|
cls,
|
||||||
|
model,
|
||||||
|
state_dict,
|
||||||
|
resolved_archive_file,
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
ignore_mismatched_sizes=False,
|
||||||
|
sharded_metadata=None,
|
||||||
|
_fast_init=True,
|
||||||
):
|
):
|
||||||
|
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
|
||||||
old_keys = []
|
|
||||||
new_keys = []
|
|
||||||
for key in state_dict.keys():
|
|
||||||
new_key = None
|
|
||||||
if "gamma" in key:
|
|
||||||
new_key = key.replace("gamma", "weight")
|
|
||||||
if "beta" in key:
|
|
||||||
new_key = key.replace("beta", "bias")
|
|
||||||
if new_key:
|
|
||||||
old_keys.append(key)
|
|
||||||
new_keys.append(new_key)
|
|
||||||
for old_key, new_key in zip(old_keys, new_keys):
|
|
||||||
state_dict[new_key] = state_dict.pop(old_key)
|
|
||||||
|
|
||||||
# Retrieve missing & unexpected_keys
|
# Retrieve missing & unexpected_keys
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
expected_keys = list(model_state_dict.keys())
|
expected_keys = list(model_state_dict.keys())
|
||||||
loaded_keys = list(state_dict.keys())
|
loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"]
|
||||||
prefix = model.base_model_prefix
|
prefix = model.base_model_prefix
|
||||||
|
|
||||||
|
def _fix_key(key):
|
||||||
|
if "beta" in key:
|
||||||
|
return key.replace("beta", "bias")
|
||||||
|
if "gamma" in key:
|
||||||
|
return key.replace("gamma", "weight")
|
||||||
|
return key
|
||||||
|
|
||||||
|
loaded_keys = [_fix_key(key) for key in loaded_keys]
|
||||||
|
|
||||||
if len(prefix) > 0:
|
if len(prefix) > 0:
|
||||||
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
||||||
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
|
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
|
||||||
@@ -1608,28 +1953,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||||
|
|
||||||
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
|
||||||
# matching the weights in the model.
|
|
||||||
mismatched_keys = []
|
|
||||||
if ignore_mismatched_sizes:
|
|
||||||
for checkpoint_key in loaded_keys:
|
|
||||||
model_key = checkpoint_key
|
|
||||||
if remove_prefix_from_model:
|
|
||||||
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
|
|
||||||
model_key = f"{prefix}.{checkpoint_key}"
|
|
||||||
elif add_prefix_to_model:
|
|
||||||
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
|
|
||||||
model_key = ".".join(checkpoint_key.split(".")[1:])
|
|
||||||
|
|
||||||
if (
|
|
||||||
model_key in model_state_dict
|
|
||||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
|
||||||
):
|
|
||||||
mismatched_keys.append(
|
|
||||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
|
||||||
)
|
|
||||||
del state_dict[checkpoint_key]
|
|
||||||
|
|
||||||
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
||||||
# the user.
|
# the user.
|
||||||
if cls._keys_to_ignore_on_load_missing is not None:
|
if cls._keys_to_ignore_on_load_missing is not None:
|
||||||
@@ -1648,35 +1971,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
for module in uninitialized_modules:
|
for module in uninitialized_modules:
|
||||||
model._init_weights(module)
|
model._init_weights(module)
|
||||||
|
|
||||||
# copy state_dict so _load_from_state_dict can modify it
|
|
||||||
metadata = getattr(state_dict, "_metadata", None)
|
|
||||||
state_dict = state_dict.copy()
|
|
||||||
if metadata is not None:
|
|
||||||
state_dict._metadata = metadata
|
|
||||||
|
|
||||||
error_msgs = []
|
|
||||||
|
|
||||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
|
||||||
# so we need to apply the function recursively.
|
|
||||||
def load(module: nn.Module, prefix=""):
|
|
||||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
|
||||||
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
import deepspeed
|
|
||||||
|
|
||||||
# because zero3 puts placeholders in model params, this context
|
|
||||||
# manager gathers (unpartitions) the params of the current layer, then loads from
|
|
||||||
# the state dict and then re-partitions them again
|
|
||||||
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
|
|
||||||
if torch.distributed.get_rank() == 0:
|
|
||||||
module._load_from_state_dict(*args)
|
|
||||||
else:
|
|
||||||
module._load_from_state_dict(*args)
|
|
||||||
|
|
||||||
for name, child in module._modules.items():
|
|
||||||
if child is not None:
|
|
||||||
load(child, prefix + name + ".")
|
|
||||||
|
|
||||||
# Make sure we are able to load base models as well as derived models (with heads)
|
# Make sure we are able to load base models as well as derived models (with heads)
|
||||||
start_prefix = ""
|
start_prefix = ""
|
||||||
model_to_load = model
|
model_to_load = model
|
||||||
@@ -1690,7 +1984,61 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
"properly saved?"
|
"properly saved?"
|
||||||
)
|
)
|
||||||
|
|
||||||
load(model_to_load, prefix=start_prefix)
|
if state_dict is not None:
|
||||||
|
# Whole checkpoint
|
||||||
|
mismatched_keys = []
|
||||||
|
if ignore_mismatched_sizes:
|
||||||
|
for checkpoint_key in loaded_keys:
|
||||||
|
model_key = checkpoint_key
|
||||||
|
if remove_prefix_from_model:
|
||||||
|
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
|
||||||
|
model_key = f"{prefix}.{checkpoint_key}"
|
||||||
|
elif add_prefix_to_model:
|
||||||
|
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
|
||||||
|
model_key = ".".join(checkpoint_key.split(".")[1:])
|
||||||
|
|
||||||
|
if (
|
||||||
|
model_key in model_state_dict
|
||||||
|
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||||
|
):
|
||||||
|
mismatched_keys.append(
|
||||||
|
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||||
|
)
|
||||||
|
del state_dict[checkpoint_key]
|
||||||
|
|
||||||
|
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
||||||
|
else:
|
||||||
|
# Sharded checkpoint
|
||||||
|
# This should always be a list but, just to be sure.
|
||||||
|
if not isinstance(resolved_archive_file, list):
|
||||||
|
resolved_archive_file = [resolved_archive_file]
|
||||||
|
|
||||||
|
error_msgs = []
|
||||||
|
for shard_file in resolved_archive_file:
|
||||||
|
state_dict = load_state_dict(shard_file)
|
||||||
|
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||||
|
# matching the weights in the model.
|
||||||
|
mismatched_keys = []
|
||||||
|
if ignore_mismatched_sizes:
|
||||||
|
for checkpoint_key in loaded_keys:
|
||||||
|
model_key = checkpoint_key
|
||||||
|
if remove_prefix_from_model:
|
||||||
|
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
|
||||||
|
model_key = f"{prefix}.{checkpoint_key}"
|
||||||
|
elif add_prefix_to_model:
|
||||||
|
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
|
||||||
|
model_key = ".".join(checkpoint_key.split(".")[1:])
|
||||||
|
|
||||||
|
if (
|
||||||
|
model_key in model_state_dict
|
||||||
|
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||||
|
):
|
||||||
|
mismatched_keys.append(
|
||||||
|
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||||
|
)
|
||||||
|
del state_dict[checkpoint_key]
|
||||||
|
|
||||||
|
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
||||||
|
|
||||||
if len(error_msgs) > 0:
|
if len(error_msgs) > 0:
|
||||||
error_msg = "\n\t".join(error_msgs)
|
error_msg = "\n\t".join(error_msgs)
|
||||||
@@ -1755,7 +2103,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
return retrieved_modules
|
return retrieved_modules
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_state_dict_into_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
|
def _load_pretrained_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
|
||||||
"""
|
"""
|
||||||
This is an experimental function that loads the model using ~1.x model size CPU memory
|
This is an experimental function that loads the model using ~1.x model size CPU memory
|
||||||
|
|
||||||
@@ -1772,7 +2120,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
|
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
require_version_core("torch>=1.9")
|
require_version_core("torch>=1.9")
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
|
raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
|
||||||
@@ -1806,19 +2153,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
new_val = new_val.to("meta")
|
new_val = new_val.to("meta")
|
||||||
setattr(submodule, param_name, new_val)
|
setattr(submodule, param_name, new_val)
|
||||||
|
|
||||||
# only now can load state_dict
|
# only now can load state_dict(s)
|
||||||
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
if not isinstance(resolved_archive_file, list):
|
||||||
|
resolved_archive_file = [resolved_archive_file]
|
||||||
|
|
||||||
# materialize state_dict entries one by one on CPU
|
for archive_file in resolved_archive_file:
|
||||||
for k in loaded_state_dict_keys:
|
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
||||||
submodule, param_name = find_submodule_and_param_name(model, k)
|
|
||||||
if submodule is not None:
|
|
||||||
new_val = state_dict[k]
|
|
||||||
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
|
|
||||||
new_val = torch.nn.Parameter(new_val)
|
|
||||||
setattr(submodule, param_name, new_val)
|
|
||||||
|
|
||||||
del state_dict
|
# materialize state_dict entries one by one on CPU
|
||||||
|
for k in loaded_state_dict_keys:
|
||||||
|
submodule, param_name = find_submodule_and_param_name(model, k)
|
||||||
|
if submodule is not None:
|
||||||
|
new_val = state_dict[k]
|
||||||
|
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
|
||||||
|
new_val = torch.nn.Parameter(new_val)
|
||||||
|
setattr(submodule, param_name, new_val)
|
||||||
|
|
||||||
|
del state_dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_for_auto_class(cls, auto_class="AutoModel"):
|
def register_for_auto_class(cls, auto_class="AutoModel"):
|
||||||
@@ -1846,12 +2197,109 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
cls._auto_class = auto_class
|
cls._auto_class = auto_class
|
||||||
|
|
||||||
|
def push_to_hub(
|
||||||
|
self,
|
||||||
|
repo_path_or_name: Optional[str] = None,
|
||||||
|
repo_url: Optional[str] = None,
|
||||||
|
use_temp_dir: bool = False,
|
||||||
|
commit_message: str = "add model",
|
||||||
|
organization: Optional[str] = None,
|
||||||
|
private: Optional[bool] = None,
|
||||||
|
use_auth_token: Optional[Union[bool, str]] = None,
|
||||||
|
max_shard_size: Union[int, str] = "10GB",
|
||||||
|
**model_card_kwargs
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
|
||||||
|
|
||||||
# To update the docstring, we need to copy the method, otherwise we change the original docstring.
|
Parameters:
|
||||||
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
|
repo_path_or_name (`str`, *optional*):
|
||||||
PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
|
Can either be a repository name for your model in the Hub or a path to a local folder (in which case
|
||||||
object="model", object_class="AutoModel", object_files="model checkpoint"
|
the repository will have the name of that local folder). If not specified, will default to the name
|
||||||
)
|
given by `repo_url` and a local directory with that name will be created.
|
||||||
|
repo_url (`str`, *optional*):
|
||||||
|
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
|
||||||
|
repository will be created in your namespace (unless you specify an `organization`) with `repo_name`.
|
||||||
|
use_temp_dir (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to clone the distant repo in a temporary directory or in `repo_path_or_name` inside the
|
||||||
|
current working directory. This will slow things down if you are making changes in an existing repo
|
||||||
|
since you will need to clone the repo before every push.
|
||||||
|
commit_message (`str`, *optional*, defaults to `"add model"`):
|
||||||
|
Message to commit while pushing.
|
||||||
|
organization (`str`, *optional*):
|
||||||
|
Organization in which you want to push your {object} (you must be a member of this organization).
|
||||||
|
private (`bool`, *optional*):
|
||||||
|
Whether or not the repository created should be private (requires a paying subscription).
|
||||||
|
use_auth_token (`bool` or `str`, *optional*):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||||
|
when running `transformers-cli login` (stored in `~/.huggingface`). Will default to `True` if
|
||||||
|
`repo_url` is not specified.
|
||||||
|
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
||||||
|
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
||||||
|
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
|
||||||
|
which will be bigger than `max_shard_size`.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`str`: The url of the commit of your {object} in the given repository.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
|
model = AutoModel.from_pretrained("bert-base-cased")
|
||||||
|
|
||||||
|
# Push the model to your namespace with the name "my-finetuned-bert" and have a local clone in the
|
||||||
|
# *my-finetuned-bert* folder.
|
||||||
|
model.push_to_hub("my-finetuned-bert")
|
||||||
|
|
||||||
|
# Push the model to your namespace with the name "my-finetuned-bert" with no local clone.
|
||||||
|
model.push_to_hub("my-finetuned-bert", use_temp_dir=True)
|
||||||
|
|
||||||
|
# Push the model to an organization with the name "my-finetuned-bert" and have a local clone in the
|
||||||
|
# *my-finetuned-bert* folder.
|
||||||
|
model.push_to_hub("my-finetuned-bert", organization="huggingface")
|
||||||
|
|
||||||
|
# Make a change to an existing repo that has been cloned locally in *my-finetuned-bert*.
|
||||||
|
model.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if use_temp_dir:
|
||||||
|
# Make sure we use the right `repo_name` for the `repo_url` before replacing it.
|
||||||
|
if repo_url is None:
|
||||||
|
if use_auth_token is None:
|
||||||
|
use_auth_token = True
|
||||||
|
repo_name = Path(repo_path_or_name).name
|
||||||
|
repo_url = self._get_repo_url_from_name(
|
||||||
|
repo_name, organization=organization, private=private, use_auth_token=use_auth_token
|
||||||
|
)
|
||||||
|
repo_path_or_name = tempfile.mkdtemp()
|
||||||
|
|
||||||
|
# Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
|
||||||
|
repo = self._create_or_get_repo(
|
||||||
|
repo_path_or_name=repo_path_or_name,
|
||||||
|
repo_url=repo_url,
|
||||||
|
organization=organization,
|
||||||
|
private=private,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
)
|
||||||
|
# Save the files in the cloned repo
|
||||||
|
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
|
||||||
|
|
||||||
|
# Commit and push!
|
||||||
|
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||||
|
|
||||||
|
# Clean up! Clean up! Everybody everywhere!
|
||||||
|
if use_temp_dir:
|
||||||
|
shutil.rmtree(repo_path_or_name)
|
||||||
|
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(nn.Module):
|
class Conv1D(nn.Module):
|
||||||
|
|||||||
@@ -136,6 +136,7 @@ from .import_utils import (
|
|||||||
|
|
||||||
|
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
|
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||||
TF_WEIGHTS_NAME = "model.ckpt"
|
TF_WEIGHTS_NAME = "model.ckpt"
|
||||||
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
|
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_flax_available, is_torch_fx_available
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||||
@@ -90,6 +90,7 @@ if is_torch_available():
|
|||||||
T5Config,
|
T5Config,
|
||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
)
|
)
|
||||||
|
from transformers.modeling_utils import shard_checkpoint
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -2352,6 +2353,123 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
self.assertTrue(torch.equal(p1, p2))
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
def test_shard_checkpoint(self):
|
||||||
|
# This is the model we will use, total size 340,000 bytes.
|
||||||
|
model = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(100, 200, bias=False), # size 80,000
|
||||||
|
torch.nn.Linear(200, 200, bias=False), # size 160,000
|
||||||
|
torch.nn.Linear(200, 100, bias=False), # size 80,000
|
||||||
|
torch.nn.Linear(100, 50, bias=False), # size 20,000
|
||||||
|
)
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
|
||||||
|
with self.subTest("No shard when max size is bigger than model size"):
|
||||||
|
shards, index = shard_checkpoint(state_dict)
|
||||||
|
self.assertIsNone(index)
|
||||||
|
self.assertDictEqual(shards, {WEIGHTS_NAME: state_dict})
|
||||||
|
|
||||||
|
with self.subTest("Test sharding, no weights bigger than max size"):
|
||||||
|
shards, index = shard_checkpoint(state_dict, max_shard_size="300kB")
|
||||||
|
# Split is first two layers then last two.
|
||||||
|
self.assertDictEqual(
|
||||||
|
index,
|
||||||
|
{
|
||||||
|
"metadata": {"total_size": 340000},
|
||||||
|
"weight_map": {
|
||||||
|
"0.weight": "pytorch_model-00001-of-00002.bin",
|
||||||
|
"1.weight": "pytorch_model-00001-of-00002.bin",
|
||||||
|
"2.weight": "pytorch_model-00002-of-00002.bin",
|
||||||
|
"3.weight": "pytorch_model-00002-of-00002.bin",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
shard1 = {"0.weight": state_dict["0.weight"], "1.weight": state_dict["1.weight"]}
|
||||||
|
shard2 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
|
||||||
|
self.assertDictEqual(
|
||||||
|
shards, {"pytorch_model-00001-of-00002.bin": shard1, "pytorch_model-00002-of-00002.bin": shard2}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.subTest("Test sharding with weights bigger than max size"):
|
||||||
|
shards, index = shard_checkpoint(state_dict, max_shard_size="100kB")
|
||||||
|
# Split is first layer, second layer then last 2.
|
||||||
|
self.assertDictEqual(
|
||||||
|
index,
|
||||||
|
{
|
||||||
|
"metadata": {"total_size": 340000},
|
||||||
|
"weight_map": {
|
||||||
|
"0.weight": "pytorch_model-00001-of-00003.bin",
|
||||||
|
"1.weight": "pytorch_model-00002-of-00003.bin",
|
||||||
|
"2.weight": "pytorch_model-00003-of-00003.bin",
|
||||||
|
"3.weight": "pytorch_model-00003-of-00003.bin",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
shard1 = {"0.weight": state_dict["0.weight"]}
|
||||||
|
shard2 = {"1.weight": state_dict["1.weight"]}
|
||||||
|
shard3 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
|
||||||
|
self.assertDictEqual(
|
||||||
|
shards,
|
||||||
|
{
|
||||||
|
"pytorch_model-00001-of-00003.bin": shard1,
|
||||||
|
"pytorch_model-00002-of-00003.bin": shard2,
|
||||||
|
"pytorch_model-00003-of-00003.bin": shard3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_local(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||||
|
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
|
||||||
|
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||||
|
|
||||||
|
# Get each shard file and its size
|
||||||
|
shard_to_size = {}
|
||||||
|
for shard in os.listdir(tmp_dir):
|
||||||
|
if shard.endswith(".bin"):
|
||||||
|
shard_file = os.path.join(tmp_dir, shard)
|
||||||
|
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||||
|
|
||||||
|
index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)
|
||||||
|
# Check there is an index but no regular weight file
|
||||||
|
self.assertTrue(os.path.isfile(index_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
# Check a file is bigger than max_size only when it has a single weight
|
||||||
|
for shard_file, size in shard_to_size.items():
|
||||||
|
if max_size.endswith("kiB"):
|
||||||
|
max_size_int = int(max_size[:-3]) * 2**10
|
||||||
|
else:
|
||||||
|
max_size_int = int(max_size[:-2]) * 10**3
|
||||||
|
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||||
|
# the size asked for (since we count parameters)
|
||||||
|
if size >= max_size_int + 50000:
|
||||||
|
state_dict = torch.load(shard_file)
|
||||||
|
self.assertEqual(len(state_dict), 1)
|
||||||
|
|
||||||
|
# Check the index and the shard files found match
|
||||||
|
with open(index_file, "r", encoding="utf-8") as f:
|
||||||
|
index = json.loads(f.read())
|
||||||
|
|
||||||
|
all_shards = set(index["weight_map"].values())
|
||||||
|
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".bin"))
|
||||||
|
self.assertSetEqual(all_shards, shards_found)
|
||||||
|
|
||||||
|
# Finally, check the model can be reloaded
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_from_hub(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||||
|
# the model above is the same as the model below, just a sharded version.
|
||||||
|
ref_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
# A mock response for an HTTP head request to emulate server down
|
# A mock response for an HTTP head request to emulate server down
|
||||||
response_mock = mock.Mock()
|
response_mock = mock.Mock()
|
||||||
|
|||||||
Reference in New Issue
Block a user