Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)

* fix all main files

* fix test files

* oups forgot modular

* add link

* update message
This commit is contained in:
Cyril Vallez
2025-04-25 16:57:09 +02:00
committed by GitHub
parent eefc86aa31
commit 0cfbf9c95b
24 changed files with 88 additions and 9 deletions

View File

@@ -94,6 +94,7 @@ from .utils import (
ModelOutput,
PushToHubMixin,
cached_file,
check_torch_load_is_safe,
copy_func,
download_url,
extract_commit_hash,
@@ -445,7 +446,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
if load_safe:
loader = safe_load_file
else:
check_torch_load_is_safe()
loader = partial(torch.load, map_location="cpu", weights_only=True)
for shard_file in shard_files:
state_dict = loader(os.path.join(folder, shard_file))
@@ -490,6 +495,7 @@ def load_state_dict(
"""
Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
"""
# Use safetensors if possible
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
@@ -512,6 +518,9 @@ def load_state_dict(
state_dict[k] = f.get_tensor(k)
return state_dict
# Fallback to torch.load (if weights_only was explicitly False, do not check safety as this is known to be unsafe)
if weights_only:
check_torch_load_is_safe()
try:
if map_location is None:
if (