Set weights_only in torch.load (#36991)

This commit is contained in:
cyyever
2025-03-27 22:55:50 +08:00
committed by GitHub
parent de77f5b1ec
commit 41a0e58e5b
28 changed files with 64 additions and 78 deletions

View File

@@ -504,8 +504,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
weights_only_kwarg = {"weights_only": True}
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
for shard_file in shard_files:
state_dict = loader(os.path.join(folder, shard_file))
@@ -598,11 +597,10 @@ def load_state_dict(
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
weights_only_kwarg = {"weights_only": weights_only}
return torch.load(
checkpoint_file,
map_location=map_location,
**weights_only_kwarg,
weights_only=weights_only,
**extra_args,
)
except Exception as e:
@@ -1216,7 +1214,7 @@ def _get_torch_dtype(
weights_only: bool,
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
infered dtype. We do the following:
inferred dtype. We do the following:
1. If torch_dtype is not None, we use that dtype
2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype