Set weights_only in torch.load (#36991)
This commit is contained in:
@@ -504,8 +504,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
weights_only_kwarg = {"weights_only": True}
|
||||
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
|
||||
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
|
||||
|
||||
for shard_file in shard_files:
|
||||
state_dict = loader(os.path.join(folder, shard_file))
|
||||
@@ -598,11 +597,10 @@ def load_state_dict(
|
||||
and is_zipfile(checkpoint_file)
|
||||
):
|
||||
extra_args = {"mmap": True}
|
||||
weights_only_kwarg = {"weights_only": weights_only}
|
||||
return torch.load(
|
||||
checkpoint_file,
|
||||
map_location=map_location,
|
||||
**weights_only_kwarg,
|
||||
weights_only=weights_only,
|
||||
**extra_args,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1216,7 +1214,7 @@ def _get_torch_dtype(
|
||||
weights_only: bool,
|
||||
) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
|
||||
"""Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
|
||||
infered dtype. We do the following:
|
||||
inferred dtype. We do the following:
|
||||
1. If torch_dtype is not None, we use that dtype
|
||||
2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||
weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
|
||||
|
||||
Reference in New Issue
Block a user