[model loading] don't gc.collect() if only 1 shard is used (#36721)
* don't gc collect if 1 shard is used * delete state dict anyways
This commit is contained in:
@@ -4831,6 +4831,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
mismatched_keys = []
|
mismatched_keys = []
|
||||||
|
has_multiple_shards = len(checkpoint_files) > 1
|
||||||
# Iterate on all the shards to load the weights
|
# Iterate on all the shards to load the weights
|
||||||
for shard_file in checkpoint_files:
|
for shard_file in checkpoint_files:
|
||||||
# Skip the load for shards that only contain disk-offloaded weights
|
# Skip the load for shards that only contain disk-offloaded weights
|
||||||
@@ -4849,7 +4850,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
):
|
):
|
||||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||||
|
|
||||||
# If shard_file is""", we use the existing state_dict instead of loading it
|
# If shard_file is "", we use the existing state_dict instead of loading it
|
||||||
if shard_file != "":
|
if shard_file != "":
|
||||||
state_dict = load_state_dict(
|
state_dict = load_state_dict(
|
||||||
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
|
||||||
@@ -4895,9 +4896,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
else:
|
else:
|
||||||
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
|
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
|
||||||
|
|
||||||
# force memory release
|
|
||||||
del state_dict
|
del state_dict
|
||||||
gc.collect()
|
# force memory release if loading multiple shards
|
||||||
|
if has_multiple_shards:
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
# Adjust offloaded weights name and save if needed
|
# Adjust offloaded weights name and save if needed
|
||||||
if disk_offload_index is not None and len(disk_offload_index) > 0:
|
if disk_offload_index is not None and len(disk_offload_index) > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user