|
|
|
|
@@ -15,11 +15,15 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import inspect
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import shutil
|
|
|
|
|
import tempfile
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from functools import partial
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
@@ -38,6 +42,7 @@ from .utils import (
|
|
|
|
|
FLAX_WEIGHTS_NAME,
|
|
|
|
|
TF2_WEIGHTS_NAME,
|
|
|
|
|
TF_WEIGHTS_NAME,
|
|
|
|
|
WEIGHTS_INDEX_NAME,
|
|
|
|
|
WEIGHTS_NAME,
|
|
|
|
|
EntryNotFoundError,
|
|
|
|
|
ModelOutput,
|
|
|
|
|
@@ -45,7 +50,6 @@ from .utils import (
|
|
|
|
|
RepositoryNotFoundError,
|
|
|
|
|
RevisionNotFoundError,
|
|
|
|
|
cached_path,
|
|
|
|
|
copy_func,
|
|
|
|
|
has_file,
|
|
|
|
|
hf_bucket_url,
|
|
|
|
|
is_offline_mode,
|
|
|
|
|
@@ -148,6 +152,272 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
|
|
|
|
|
return first_tuple[1].dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_file_size_to_int(size: Union[int, str]):
    """
    Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).

    Binary units (`KiB`, `MiB`, `GiB`) are powers of 1024 whereas decimal units (`KB`, `MB`, `GB`) are powers of
    1000. The unit is matched case-insensitively.

    Args:
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

    Returns:
        `int`: The size expressed in bytes.

    Raises:
        `ValueError`: If `size` is a string that is not an integer followed by one of the supported units.

    Example:

    ```py
    >>> convert_file_size_to_int("1MiB")
    1048576
    >>> convert_file_size_to_int("1MB")
    1000000
    ```
    """
    if isinstance(size, int):
        return size

    invalid_format = "`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'."
    # Three-letter binary suffixes are listed before the two-letter decimal ones.
    units = (
        ("GIB", 2**30),
        ("MIB", 2**20),
        ("KIB", 2**10),
        ("GB", 10**9),
        ("MB", 10**6),
        ("KB", 10**3),
    )
    upper_size = size.upper()
    for suffix, multiplier in units:
        if upper_size.endswith(suffix):
            try:
                return int(size[: -len(suffix)]) * multiplier
            except ValueError as err:
                # The part in front of the unit is not an integer: report the expected format.
                raise ValueError(invalid_format) from err
    raise ValueError(invalid_format)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    The bit width is parsed from the trailing digits of `str(dtype)` (e.g. the `32` of `torch.float32`).

    Args:
        dtype: The dtype to measure, e.g. a `torch.dtype`.

    Returns:
        `int` or `float`: The number of bytes per element. `torch.bool` is special-cased and counted as `1 / 8` of a
        byte, so the result is a float in that case.

    Raises:
        `ValueError`: If the string representation of `dtype` does not end with a bit width.

    Example:

    ```py
    >>> dtype_byte_size(torch.float32)
    4
    ```
    """
    if dtype == torch.bool:
        return 1 / 8
    # Raw string so that `\d` is a regex digit class rather than an (invalid) string escape.
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.group(1))
    return bit_size // 8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).

    Returns:
        `tuple`: A pair `(shards, index)`. `shards` maps each checkpoint file name to the state dictionary to save
        under that name. `index` is `None` when everything fits in a single shard; otherwise it is a dictionary with a
        `"metadata"` entry (holding `"total_size"` in bytes) and a `"weight_map"` entry mapping each weight name to
        the shard file that contains it.
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    for key, weight in state_dict.items():
        weight_size = weight.numel() * dtype_byte_size(weight.dtype)

        # If this weight is going to tip up over the maximal size, we split. An empty block is never flushed:
        # otherwise a first weight bigger than `max_shard_size` would leave an empty shard in front of it.
        if current_block and current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0

        current_block[key] = weight
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    num_shards = len(sharded_state_dicts)
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{num_shards:05d}.bin")
        shards[shard_file] = shard
        for key in shard:
            weight_map[key] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    local_files_only=False,
    use_auth_token=None,
    user_agent=None,
    revision=None,
    mirror=None,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
      Hub
    - returns the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).

    Returns:
        `tuple`: A pair `(shard_files, sharded_metadata)`. `shard_files` is the list of paths to the shards, following
        the sorted order of the shard file names found in the index. `sharded_metadata` is the `"metadata"` entry of
        the index, extended with an `"all_checkpoint_keys"` entry listing every weight name of the index.

    Raises:
        `EnvironmentError`: If a shard listed in the index cannot be found or downloaded.
    """
    # The index is written as UTF-8 JSON, so read it back with the same encoding.
    with open(index_filename, "r", encoding="utf-8") as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    shard_filenames = sorted(set(weight_map.values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(weight_map.keys())

    # First, let's deal with local folder.
    if os.path.isdir(pretrained_model_name_or_path):
        shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
        return shard_filenames, sharded_metadata

    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
    cached_filenames = []
    for shard_filename in shard_filenames:
        shard_url = hf_bucket_url(
            pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
        )

        try:
            # Load from URL
            cached_filename = cached_path(
                shard_url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
        # we don't have to catch them here.
        except EntryNotFoundError as err:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
                "required according to the checkpoint index."
            ) from err
        except HTTPError as err:
            raise EnvironmentError(
                f"We couldn't connect to 'https://huggingface.co/' to load {shard_filename}. You should try again "
                "after checking your internet connection."
            ) from err

        cached_filenames.append(cached_filename)

    return cached_filenames, sharded_metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
    """
    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.

    Args:
        checkpoint_file (`str` or `os.PathLike`): Path to the checkpoint file to load.

    Returns:
        The object stored in the checkpoint, as returned by `torch.load` with `map_location="cpu"`.

    Raises:
        `OSError`: If `torch.load` fails. When the file is text starting with `"version"`, the message points at a
        missing git-lfs checkout; otherwise it reports that the weights could not be loaded.
    """
    try:
        return torch.load(checkpoint_file, map_location="cpu")
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                # Only the first characters are needed to recognize the text case, so the (potentially huge)
                # file is not read in full.
                starts_with_version = f.read(7) == "version"
        except ValueError:
            # `UnicodeDecodeError` is a subclass of `ValueError`: the file is not readable as text.
            starts_with_version = False

        if starts_with_version:
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please install "
                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                "you cloned."
            )
        raise OSError(
            f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
            f"at '{checkpoint_file}'. "
            "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
        ) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    """
    Copies the weights of `state_dict` into `model_to_load` and returns the error messages collected on the way.

    Keys containing `"gamma"` / `"beta"` are first renamed (in place, in the `state_dict` passed) to use `"weight"` /
    `"bias"`. The loading itself then works on a shallow copy of `state_dict`, walking `model_to_load` and all its
    submodules starting from the key prefix `start_prefix`.

    Returns:
        `List[str]`: The error messages accumulated by the `_load_from_state_dict` calls.
    """
    # Convert old format to new format if needed from a PyTorch state_dict
    pending_renames = []
    for name in state_dict.keys():
        # When both substrings are present, the "beta" replacement is the one that is kept.
        if "beta" in name:
            renamed = name.replace("beta", "bias")
        elif "gamma" in name:
            renamed = name.replace("gamma", "weight")
        else:
            continue
        if renamed:
            pending_renames.append((name, renamed))
    # Renaming happens after the scan so the dictionary is not modified while it is iterated.
    for former_name, renamed in pending_renames:
        state_dict[renamed] = state_dict.pop(former_name)

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def _load_recursively(module: nn.Module, prefix=""):
        if metadata is None:
            local_metadata = {}
        else:
            # `prefix` ends with a dot (or is empty); metadata is looked up without that trailing dot.
            local_metadata = metadata.get(prefix[:-1], {})
        load_args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)

        if not is_deepspeed_zero3_enabled():
            module._load_from_state_dict(*load_args)
        else:
            import deepspeed

            # because zero3 puts placeholders in model params, this context
            # manager gathers (unpartitions) the params of the current layer, then loads from
            # the state dict and then re-partitions them again
            own_parameters = list(module.parameters(recurse=False))
            with deepspeed.zero.GatheredParameters(own_parameters, modifier_rank=0):
                if torch.distributed.get_rank() == 0:
                    module._load_from_state_dict(*load_args)

        for child_name, child in module._modules.items():
            if child is None:
                continue
            _load_recursively(child, prefix + child_name + ".")

    _load_recursively(model_to_load, prefix=start_prefix)

    return error_msgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModuleUtilsMixin:
|
|
|
|
|
"""
|
|
|
|
|
A few utilities for `torch.nn.Modules`, to be used as a mixin.
|
|
|
|
|
@@ -1004,6 +1274,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
state_dict: Optional[dict] = None,
|
|
|
|
|
save_function: Callable = torch.save,
|
|
|
|
|
push_to_hub: bool = False,
|
|
|
|
|
max_shard_size: Union[int, str] = "10GB",
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
@@ -1035,6 +1306,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
|
|
|
|
|
</Tip>
|
|
|
|
|
|
|
|
|
|
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
|
|
|
|
|
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
|
|
|
|
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
|
|
|
|
|
|
|
|
|
|
<Tip warning={true}>
|
|
|
|
|
|
|
|
|
|
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
|
|
|
|
|
which will be bigger than `max_shard_size`.
|
|
|
|
|
|
|
|
|
|
</Tip>
|
|
|
|
|
|
|
|
|
|
kwargs:
|
|
|
|
|
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
|
|
|
|
"""
|
|
|
|
|
@@ -1078,11 +1360,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
if ignore_key in state_dict.keys():
|
|
|
|
|
del state_dict[ignore_key]
|
|
|
|
|
|
|
|
|
|
# If we save using the predefined names, we can load using `from_pretrained`
|
|
|
|
|
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
|
|
|
|
save_function(state_dict, output_model_file)
|
|
|
|
|
# Shard the model if it is too big.
|
|
|
|
|
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Model weights saved in {output_model_file}")
|
|
|
|
|
# Clean the folder from a previous save
|
|
|
|
|
for filename in os.listdir(save_directory):
|
|
|
|
|
full_filename = os.path.join(save_directory, filename)
|
|
|
|
|
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
|
|
|
|
|
os.remove(full_filename)
|
|
|
|
|
|
|
|
|
|
# Save the model
|
|
|
|
|
for shard_file, shard in shards.items():
|
|
|
|
|
save_function(shard, os.path.join(save_directory, shard_file))
|
|
|
|
|
|
|
|
|
|
if index is None:
|
|
|
|
|
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
|
|
|
|
|
else:
|
|
|
|
|
save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
|
|
|
|
|
# Save the index as well
|
|
|
|
|
with open(save_index_file, "w", encoding="utf-8") as f:
|
|
|
|
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
|
|
|
|
f.write(content)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
|
|
|
|
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
|
|
|
|
|
f"index located at {save_index_file}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if push_to_hub:
|
|
|
|
|
url = self._push_to_hub(repo, commit_message=commit_message)
|
|
|
|
|
@@ -1293,6 +1596,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
else:
|
|
|
|
|
model_kwargs = kwargs
|
|
|
|
|
|
|
|
|
|
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
|
|
|
|
# index of the files.
|
|
|
|
|
is_sharded = False
|
|
|
|
|
sharded_metadata = None
|
|
|
|
|
# Load model
|
|
|
|
|
if pretrained_model_name_or_path is not None:
|
|
|
|
|
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
|
|
|
|
@@ -1309,6 +1616,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
|
|
|
|
# Load from a PyTorch checkpoint
|
|
|
|
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
|
|
|
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
|
|
|
|
|
# Load from a sharded PyTorch checkpoint
|
|
|
|
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
|
|
|
|
is_sharded = True
|
|
|
|
|
# At this stage we don't have a weight file so we will raise an error.
|
|
|
|
|
elif os.path.isfile(
|
|
|
|
|
os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
|
|
|
|
@@ -1382,6 +1693,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
)
|
|
|
|
|
except EntryNotFoundError:
|
|
|
|
|
if filename == WEIGHTS_NAME:
|
|
|
|
|
try:
|
|
|
|
|
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
|
|
|
|
archive_file = hf_bucket_url(
|
|
|
|
|
pretrained_model_name_or_path,
|
|
|
|
|
filename=WEIGHTS_INDEX_NAME,
|
|
|
|
|
revision=revision,
|
|
|
|
|
mirror=mirror,
|
|
|
|
|
)
|
|
|
|
|
resolved_archive_file = cached_path(
|
|
|
|
|
archive_file,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
force_download=force_download,
|
|
|
|
|
proxies=proxies,
|
|
|
|
|
resume_download=resume_download,
|
|
|
|
|
local_files_only=local_files_only,
|
|
|
|
|
use_auth_token=use_auth_token,
|
|
|
|
|
user_agent=user_agent,
|
|
|
|
|
)
|
|
|
|
|
is_sharded = True
|
|
|
|
|
except EntryNotFoundError:
|
|
|
|
|
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
|
|
|
|
|
# message.
|
|
|
|
|
has_file_kwargs = {
|
|
|
|
|
"revision": revision,
|
|
|
|
|
"mirror": mirror,
|
|
|
|
|
@@ -1439,29 +1772,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
else:
|
|
|
|
|
resolved_archive_file = None
|
|
|
|
|
|
|
|
|
|
# load pt weights early so that we know which dtype to init the model under
|
|
|
|
|
if from_pt:
|
|
|
|
|
if state_dict is None:
|
|
|
|
|
try:
|
|
|
|
|
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
try:
|
|
|
|
|
with open(resolved_archive_file) as f:
|
|
|
|
|
if f.read().startswith("version"):
|
|
|
|
|
raise OSError(
|
|
|
|
|
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
|
|
|
|
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
|
|
|
|
"you cloned."
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError from e
|
|
|
|
|
except (UnicodeDecodeError, ValueError):
|
|
|
|
|
raise OSError(
|
|
|
|
|
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
|
|
|
|
f"at '{resolved_archive_file}'. "
|
|
|
|
|
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
|
|
|
|
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
|
|
|
|
|
if is_sharded:
|
|
|
|
|
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
|
|
|
|
|
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
|
|
|
|
|
pretrained_model_name_or_path,
|
|
|
|
|
resolved_archive_file,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
force_download=force_download,
|
|
|
|
|
proxies=proxies,
|
|
|
|
|
resume_download=resume_download,
|
|
|
|
|
local_files_only=local_files_only,
|
|
|
|
|
use_auth_token=use_auth_token,
|
|
|
|
|
user_agent=user_agent,
|
|
|
|
|
revision=revision,
|
|
|
|
|
mirror=mirror,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# load pt weights early so that we know which dtype to init the model under
|
|
|
|
|
if from_pt:
|
|
|
|
|
if not is_sharded:
|
|
|
|
|
# Time to load the checkpoint
|
|
|
|
|
state_dict = load_state_dict(resolved_archive_file)
|
|
|
|
|
# set dtype to instantiate the model under:
|
|
|
|
|
# 1. If torch_dtype is not None, we use that dtype
|
|
|
|
|
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
|
|
|
|
@@ -1471,7 +1803,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
if torch_dtype is not None:
|
|
|
|
|
if isinstance(torch_dtype, str):
|
|
|
|
|
if torch_dtype == "auto":
|
|
|
|
|
if is_sharded and "dtype" in sharded_metadata:
|
|
|
|
|
torch_dtype = sharded_metadata["dtype"]
|
|
|
|
|
elif not is_sharded:
|
|
|
|
|
torch_dtype = next(iter(state_dict.values())).dtype
|
|
|
|
|
else:
|
|
|
|
|
one_state_dict = load_state_dict(resolved_archive_file)
|
|
|
|
|
torch_dtype = next(iter(one_state_dict.values())).dtype
|
|
|
|
|
del one_state_dict # free CPU memory
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
|
|
|
|
|
@@ -1480,6 +1819,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
|
|
|
|
|
if low_cpu_mem_usage:
|
|
|
|
|
# save the keys
|
|
|
|
|
if is_sharded:
|
|
|
|
|
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
|
|
|
|
else:
|
|
|
|
|
state_dict = load_state_dict(resolved_archive_file)
|
|
|
|
|
loaded_state_dict_keys = [k for k in state_dict.keys()]
|
|
|
|
|
del state_dict # free CPU memory - will reload again later
|
|
|
|
|
|
|
|
|
|
@@ -1534,13 +1877,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
elif from_pt:
|
|
|
|
|
|
|
|
|
|
if low_cpu_mem_usage:
|
|
|
|
|
cls._load_state_dict_into_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
|
|
|
|
|
cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
|
|
|
|
|
else:
|
|
|
|
|
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model(
|
|
|
|
|
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
|
|
|
|
model,
|
|
|
|
|
state_dict,
|
|
|
|
|
resolved_archive_file,
|
|
|
|
|
pretrained_model_name_or_path,
|
|
|
|
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
|
|
|
|
sharded_metadata=sharded_metadata,
|
|
|
|
|
_fast_init=_fast_init,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -1562,31 +1907,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _load_state_dict_into_model(
|
|
|
|
|
cls, model, state_dict, pretrained_model_name_or_path, ignore_mismatched_sizes=False, _fast_init=True
|
|
|
|
|
def _load_pretrained_model(
|
|
|
|
|
cls,
|
|
|
|
|
model,
|
|
|
|
|
state_dict,
|
|
|
|
|
resolved_archive_file,
|
|
|
|
|
pretrained_model_name_or_path,
|
|
|
|
|
ignore_mismatched_sizes=False,
|
|
|
|
|
sharded_metadata=None,
|
|
|
|
|
_fast_init=True,
|
|
|
|
|
):
|
|
|
|
|
|
|
|
|
|
# Convert old format to new format if needed from a PyTorch state_dict
|
|
|
|
|
old_keys = []
|
|
|
|
|
new_keys = []
|
|
|
|
|
for key in state_dict.keys():
|
|
|
|
|
new_key = None
|
|
|
|
|
if "gamma" in key:
|
|
|
|
|
new_key = key.replace("gamma", "weight")
|
|
|
|
|
if "beta" in key:
|
|
|
|
|
new_key = key.replace("beta", "bias")
|
|
|
|
|
if new_key:
|
|
|
|
|
old_keys.append(key)
|
|
|
|
|
new_keys.append(new_key)
|
|
|
|
|
for old_key, new_key in zip(old_keys, new_keys):
|
|
|
|
|
state_dict[new_key] = state_dict.pop(old_key)
|
|
|
|
|
|
|
|
|
|
# Retrieve missing & unexpected_keys
|
|
|
|
|
model_state_dict = model.state_dict()
|
|
|
|
|
expected_keys = list(model_state_dict.keys())
|
|
|
|
|
loaded_keys = list(state_dict.keys())
|
|
|
|
|
loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"]
|
|
|
|
|
prefix = model.base_model_prefix
|
|
|
|
|
|
|
|
|
|
def _fix_key(key):
|
|
|
|
|
if "beta" in key:
|
|
|
|
|
return key.replace("beta", "bias")
|
|
|
|
|
if "gamma" in key:
|
|
|
|
|
return key.replace("gamma", "weight")
|
|
|
|
|
return key
|
|
|
|
|
|
|
|
|
|
loaded_keys = [_fix_key(key) for key in loaded_keys]
|
|
|
|
|
|
|
|
|
|
if len(prefix) > 0:
|
|
|
|
|
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
|
|
|
|
|
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
|
|
|
|
|
@@ -1608,6 +1953,69 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
|
|
|
|
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
|
|
|
|
|
|
|
|
|
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
|
|
|
|
# the user.
|
|
|
|
|
if cls._keys_to_ignore_on_load_missing is not None:
|
|
|
|
|
for pat in cls._keys_to_ignore_on_load_missing:
|
|
|
|
|
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
|
|
|
|
|
|
|
|
|
if cls._keys_to_ignore_on_load_unexpected is not None:
|
|
|
|
|
for pat in cls._keys_to_ignore_on_load_unexpected:
|
|
|
|
|
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
|
|
|
|
|
|
|
|
|
if _fast_init:
|
|
|
|
|
# retrieve uninitialized modules and initialize
|
|
|
|
|
uninitialized_modules = model.retrieve_modules_from_names(
|
|
|
|
|
missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model
|
|
|
|
|
)
|
|
|
|
|
for module in uninitialized_modules:
|
|
|
|
|
model._init_weights(module)
|
|
|
|
|
|
|
|
|
|
# Make sure we are able to load base models as well as derived models (with heads)
|
|
|
|
|
start_prefix = ""
|
|
|
|
|
model_to_load = model
|
|
|
|
|
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
|
|
|
|
|
start_prefix = cls.base_model_prefix + "."
|
|
|
|
|
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
|
|
|
|
model_to_load = getattr(model, cls.base_model_prefix)
|
|
|
|
|
if any(key in expected_keys_not_prefixed for key in loaded_keys):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The state dictionary of the model you are training to load is corrupted. Are you sure it was "
|
|
|
|
|
"properly saved?"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if state_dict is not None:
|
|
|
|
|
# Whole checkpoint
|
|
|
|
|
mismatched_keys = []
|
|
|
|
|
if ignore_mismatched_sizes:
|
|
|
|
|
for checkpoint_key in loaded_keys:
|
|
|
|
|
model_key = checkpoint_key
|
|
|
|
|
if remove_prefix_from_model:
|
|
|
|
|
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
|
|
|
|
|
model_key = f"{prefix}.{checkpoint_key}"
|
|
|
|
|
elif add_prefix_to_model:
|
|
|
|
|
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
|
|
|
|
|
model_key = ".".join(checkpoint_key.split(".")[1:])
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
|
|
model_key in model_state_dict
|
|
|
|
|
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
|
|
|
|
):
|
|
|
|
|
mismatched_keys.append(
|
|
|
|
|
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
|
|
|
|
)
|
|
|
|
|
del state_dict[checkpoint_key]
|
|
|
|
|
|
|
|
|
|
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
|
|
|
|
else:
|
|
|
|
|
# Sharded checkpoint
|
|
|
|
|
# This should always be a list but, just to be sure.
|
|
|
|
|
if not isinstance(resolved_archive_file, list):
|
|
|
|
|
resolved_archive_file = [resolved_archive_file]
|
|
|
|
|
|
|
|
|
|
error_msgs = []
|
|
|
|
|
for shard_file in resolved_archive_file:
|
|
|
|
|
state_dict = load_state_dict(shard_file)
|
|
|
|
|
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
|
|
|
|
# matching the weights in the model.
|
|
|
|
|
mismatched_keys = []
|
|
|
|
|
@@ -1630,67 +2038,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
)
|
|
|
|
|
del state_dict[checkpoint_key]
|
|
|
|
|
|
|
|
|
|
# Some models may have keys that are not in the state by design, removing them before needlessly warning
|
|
|
|
|
# the user.
|
|
|
|
|
if cls._keys_to_ignore_on_load_missing is not None:
|
|
|
|
|
for pat in cls._keys_to_ignore_on_load_missing:
|
|
|
|
|
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
|
|
|
|
|
|
|
|
|
|
if cls._keys_to_ignore_on_load_unexpected is not None:
|
|
|
|
|
for pat in cls._keys_to_ignore_on_load_unexpected:
|
|
|
|
|
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
|
|
|
|
|
|
|
|
|
if _fast_init:
|
|
|
|
|
# retrieve unintialized modules and initialize
|
|
|
|
|
uninitialized_modules = model.retrieve_modules_from_names(
|
|
|
|
|
missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model
|
|
|
|
|
)
|
|
|
|
|
for module in uninitialized_modules:
|
|
|
|
|
model._init_weights(module)
|
|
|
|
|
|
|
|
|
|
# copy state_dict so _load_from_state_dict can modify it
|
|
|
|
|
metadata = getattr(state_dict, "_metadata", None)
|
|
|
|
|
state_dict = state_dict.copy()
|
|
|
|
|
if metadata is not None:
|
|
|
|
|
state_dict._metadata = metadata
|
|
|
|
|
|
|
|
|
|
error_msgs = []
|
|
|
|
|
|
|
|
|
|
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
|
|
|
|
# so we need to apply the function recursively.
|
|
|
|
|
def load(module: nn.Module, prefix=""):
|
|
|
|
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
|
|
|
|
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
|
|
|
|
if is_deepspeed_zero3_enabled():
|
|
|
|
|
import deepspeed
|
|
|
|
|
|
|
|
|
|
# because zero3 puts placeholders in model params, this context
|
|
|
|
|
# manager gathers (unpartitions) the params of the current layer, then loads from
|
|
|
|
|
# the state dict and then re-partitions them again
|
|
|
|
|
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
|
|
|
|
|
if torch.distributed.get_rank() == 0:
|
|
|
|
|
module._load_from_state_dict(*args)
|
|
|
|
|
else:
|
|
|
|
|
module._load_from_state_dict(*args)
|
|
|
|
|
|
|
|
|
|
for name, child in module._modules.items():
|
|
|
|
|
if child is not None:
|
|
|
|
|
load(child, prefix + name + ".")
|
|
|
|
|
|
|
|
|
|
# Make sure we are able to load base models as well as derived models (with heads)
|
|
|
|
|
start_prefix = ""
|
|
|
|
|
model_to_load = model
|
|
|
|
|
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
|
|
|
|
|
start_prefix = cls.base_model_prefix + "."
|
|
|
|
|
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
|
|
|
|
model_to_load = getattr(model, cls.base_model_prefix)
|
|
|
|
|
if any(key in expected_keys_not_prefixed for key in loaded_keys):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The state dictionary of the model you are training to load is corrupted. Are you sure it was "
|
|
|
|
|
"properly saved?"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
load(model_to_load, prefix=start_prefix)
|
|
|
|
|
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
|
|
|
|
|
|
|
|
|
|
if len(error_msgs) > 0:
|
|
|
|
|
error_msg = "\n\t".join(error_msgs)
|
|
|
|
|
@@ -1755,7 +2103,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
return retrieved_modules
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _load_state_dict_into_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
|
|
|
|
|
def _load_pretrained_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
|
|
|
|
|
"""
|
|
|
|
|
This is an experimental function that loads the model using ~1.x model size CPU memory
|
|
|
|
|
|
|
|
|
|
@@ -1772,7 +2120,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
|
|
|
|
|
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
require_version_core("torch>=1.9")
|
|
|
|
|
if is_deepspeed_zero3_enabled():
|
|
|
|
|
raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
|
|
|
|
|
@@ -1806,7 +2153,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
new_val = new_val.to("meta")
|
|
|
|
|
setattr(submodule, param_name, new_val)
|
|
|
|
|
|
|
|
|
|
# only now can load state_dict
|
|
|
|
|
# only now can load state_dict(s)
|
|
|
|
|
if not isinstance(resolved_archive_file, list):
|
|
|
|
|
resolved_archive_file = [resolved_archive_file]
|
|
|
|
|
|
|
|
|
|
for archive_file in resolved_archive_file:
|
|
|
|
|
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
|
|
|
|
|
|
|
|
|
# materialize state_dict entries one by one on CPU
|
|
|
|
|
@@ -1846,12 +2197,109 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|
|
|
|
|
|
|
|
|
cls._auto_class = auto_class
|
|
|
|
|
|
|
|
|
|
def push_to_hub(
    self,
    repo_path_or_name: Optional[str] = None,
    repo_url: Optional[str] = None,
    use_temp_dir: bool = False,
    commit_message: str = "add model",
    organization: Optional[str] = None,
    private: Optional[bool] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    max_shard_size: Union[int, str] = "10GB",
    **model_card_kwargs
) -> str:
    """
    Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.

    Parameters:
        repo_path_or_name (`str`, *optional*):
            Can either be a repository name for your model in the Hub or a path to a local folder (in which case
            the repository will have the name of that local folder). If not specified, will default to the name
            given by `repo_url` and a local directory with that name will be created.
        repo_url (`str`, *optional*):
            Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
            repository will be created in your namespace (unless you specify an `organization`) with `repo_name`.
        use_temp_dir (`bool`, *optional*, defaults to `False`):
            Whether or not to clone the distant repo in a temporary directory or in `repo_path_or_name` inside the
            current working directory. This will slow things down if you are making changes in an existing repo
            since you will need to clone the repo before every push.
        commit_message (`str`, *optional*, defaults to `"add model"`):
            Message to commit while pushing.
        organization (`str`, *optional*):
            Organization in which you want to push your model (you must be a member of this organization).
        private (`bool`, *optional*):
            Whether or not the repository created should be private (requires a paying subscription).
        use_auth_token (`bool` or `str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`). Will default to `True` if
            `repo_url` is not specified.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
            lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).

            <Tip warning={true}>

            If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
            which will be bigger than `max_shard_size`.

            </Tip>

        model_card_kwargs:
            Extra keyword arguments. They are accepted but currently not used by this method.

    Returns:
        `str`: The url of the commit of your model in the given repository.

    Examples:

    ```python
    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-cased")

    # Push the model to your namespace with the name "my-finetuned-bert" and have a local clone in the
    # *my-finetuned-bert* folder.
    model.push_to_hub("my-finetuned-bert")

    # Push the model to your namespace with the name "my-finetuned-bert" with no local clone.
    model.push_to_hub("my-finetuned-bert", use_temp_dir=True)

    # Push the model to an organization with the name "my-finetuned-bert" and have a local clone in the
    # *my-finetuned-bert* folder.
    model.push_to_hub("my-finetuned-bert", organization="huggingface")

    # Make a change to an existing repo that has been cloned locally in *my-finetuned-bert*.
    model.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
    ```
    """
    if use_temp_dir:
        # Make sure we use the right `repo_name` for the `repo_url` before replacing it.
        if repo_url is None:
            if use_auth_token is None:
                use_auth_token = True
            repo_name = Path(repo_path_or_name).name
            repo_url = self._get_repo_url_from_name(
                repo_name, organization=organization, private=private, use_auth_token=use_auth_token
            )
        repo_path_or_name = tempfile.mkdtemp()

    try:
        # Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
        repo = self._create_or_get_repo(
            repo_path_or_name=repo_path_or_name,
            repo_url=repo_url,
            organization=organization,
            private=private,
            use_auth_token=use_auth_token,
        )
        # Save the files in the cloned repo
        self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)

        # Commit and push!
        url = self._push_to_hub(repo, commit_message=commit_message)
    finally:
        # Clean up in a `finally` so the temporary clone is removed even when cloning, saving or pushing
        # raises; otherwise every failed push would leave a full checkpoint behind in the temp directory.
        if use_temp_dir:
            shutil.rmtree(repo_path_or_name)

    return url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Conv1D(nn.Module):
|
|
|
|
|
|