Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)

* fix all main files

* fix test files

* oups forgot modular

* add link

* update message
This commit is contained in:
Cyril Vallez
2025-04-25 16:57:09 +02:00
committed by GitHub
parent eefc86aa31
commit 0cfbf9c95b
24 changed files with 88 additions and 9 deletions

View File

@@ -74,6 +74,7 @@ from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
check_torch_load_is_safe,
)
from transformers.utils.import_utils import (
is_flash_attn_2_available,
@@ -739,6 +740,7 @@ class ModelUtilsTest(TestCasePlus):
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
check_torch_load_is_safe()
state_dict = torch.load(shard_file, weights_only=True)
self.assertEqual(len(state_dict), 1)