diff --git a/setup.py b/setup.py index 3d6c78fd9a..4edffc724e 100644 --- a/setup.py +++ b/setup.py @@ -117,7 +117,7 @@ _deps = [ "fugashi>=1.0", "GitPython<3.1.19", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.23.0,<1.0", + "huggingface-hub>=0.23.2,<1.0", "importlib_metadata", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 29c916aff6..3148d0f339 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -24,7 +24,7 @@ deps = { "fugashi": "fugashi>=1.0", "GitPython": "GitPython<3.1.19", "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.23.0,<1.0", + "huggingface-hub": "huggingface-hub>=0.23.2,<1.0", "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 263b24d815..1b0456eff9 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -34,6 +34,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from zipfile import is_zipfile import torch +from huggingface_hub import split_torch_state_dict_into_shards from packaging import version from torch import Tensor, nn from torch.nn import CrossEntropyLoss, Identity @@ -362,6 +363,10 @@ def shard_checkpoint( weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`): The name of the model save file. """ + logger.warning( + "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using " + "split_torch_state_dict_into_shards from huggingface_hub library" + ) max_shard_size = convert_file_size_to_int(max_shard_size) sharded_state_dicts = [{}] @@ -2618,7 +2623,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME - shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + # Save index if sharded + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } # Clean the folder from a previous save for filename in os.listdir(save_directory): @@ -2634,14 +2649,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if ( filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) - and filename not in shards.keys() + and filename not in state_dict_split.filename_to_tensors.keys() and is_main_process and reg.fullmatch(filename_no_suffix) is not None ): os.remove(full_filename) # Save the model - for shard_file, shard in shards.items(): + for shard_file, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} # remake shard with onloaded parameters if necessary if module_map: if accelerate_version < version.parse("0.31"): @@ -2680,7 +2696,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix f.write(content) logger.info( f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " - f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the " f"index located at {save_index_file}." ) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 8a2db45d9b..4c76071574 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -669,7 +669,7 @@ class ModelUtilsTest(TestCasePlus): with tempfile.TemporaryDirectory() as tmp_dir: # We use the same folder for various sizes to make sure a new save erases the old checkpoint. - for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]: + for max_size in ["50kB", "100kB", "200kB"]: model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False) # Get each shard file and its size @@ -686,10 +686,7 @@ class ModelUtilsTest(TestCasePlus): # Check a file is bigger than max_size only when it has a single weight for shard_file, size in shard_to_size.items(): - if max_size.endswith("kiB"): - max_size_int = int(max_size[:-3]) * 2**10 - else: - max_size_int = int(max_size[:-2]) * 10**3 + max_size_int = int(max_size[:-2]) * 10**3 # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than # the size asked for (since we count parameters) if size >= max_size_int + 50000: