Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)
* fix all main files * fix test files * oups forgot modular * add link * update message
This commit is contained in:
@@ -94,6 +94,7 @@ from .utils import (
|
||||
ModelOutput,
|
||||
PushToHubMixin,
|
||||
cached_file,
|
||||
check_torch_load_is_safe,
|
||||
copy_func,
|
||||
download_url,
|
||||
extract_commit_hash,
|
||||
@@ -445,7 +446,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
error_message += f"\nMissing key(s): {str_unexpected_keys}."
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
|
||||
if load_safe:
|
||||
loader = safe_load_file
|
||||
else:
|
||||
check_torch_load_is_safe()
|
||||
loader = partial(torch.load, map_location="cpu", weights_only=True)
|
||||
|
||||
for shard_file in shard_files:
|
||||
state_dict = loader(os.path.join(folder, shard_file))
|
||||
@@ -490,6 +495,7 @@ def load_state_dict(
|
||||
"""
|
||||
Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
|
||||
"""
|
||||
# Use safetensors if possible
|
||||
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
@@ -512,6 +518,9 @@ def load_state_dict(
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
return state_dict
|
||||
|
||||
# Fallback to torch.load (if weights_only was explicitly False, do not check safety as this is known to be unsafe)
|
||||
if weights_only:
|
||||
check_torch_load_is_safe()
|
||||
try:
|
||||
if map_location is None:
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user