[model loading] don't gc.collect() if only 1 shard is used (#36721)

* don't gc collect if 1 shard is used

* delete state dict anyways
This commit is contained in:
Joao Gante
2025-03-14 12:56:56 +00:00
committed by GitHub
parent 8cb522b419
commit 3bd1a0ddf1

View File

@@ -4831,6 +4831,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
error_msgs = [] error_msgs = []
mismatched_keys = [] mismatched_keys = []
has_multiple_shards = len(checkpoint_files) > 1
# Iterate on all the shards to load the weights # Iterate on all the shards to load the weights
for shard_file in checkpoint_files: for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights # Skip the load for shards that only contain disk-offloaded weights
@@ -4849,7 +4850,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
): ):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
# If shard_file is "", we use the existing state_dict instead of loading it # If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "": if shard_file != "":
state_dict = load_state_dict( state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
@@ -4895,9 +4896,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else: else:
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
# force memory release
del state_dict del state_dict
gc.collect() # force memory release if loading multiple shards
if has_multiple_shards:
gc.collect()
# Adjust offloaded weights name and save if needed # Adjust offloaded weights name and save if needed
if disk_offload_index is not None and len(disk_offload_index) > 0: if disk_offload_index is not None and len(disk_offload_index) > 0: