Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)
* fix all main files * fix test files * oups forgot modular * add link * update message
This commit is contained in:
@@ -74,6 +74,7 @@ from transformers.utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
check_torch_load_is_safe,
|
||||
)
|
||||
from transformers.utils.import_utils import (
|
||||
is_flash_attn_2_available,
|
||||
@@ -739,6 +740,7 @@ class ModelUtilsTest(TestCasePlus):
|
||||
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||
# the size asked for (since we count parameters)
|
||||
if size >= max_size_int + 50000:
|
||||
check_torch_load_is_safe()
|
||||
state_dict = torch.load(shard_file, weights_only=True)
|
||||
self.assertEqual(len(state_dict), 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user