Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)

* fix all main files

* fix test files

* oups forgot modular

* add link

* update message
This commit is contained in:
Cyril Vallez
2025-04-25 16:57:09 +02:00
committed by GitHub
parent eefc86aa31
commit 0cfbf9c95b
24 changed files with 88 additions and 9 deletions

View File

@@ -113,6 +113,7 @@ from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
check_torch_load_is_safe,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
@@ -646,6 +647,7 @@ class TrainerIntegrationCommon:
else:
best_model = RegressionModel()
if not safe_weights:
check_torch_load_is_safe()
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME), weights_only=True)
else:
state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME))
@@ -678,6 +680,7 @@ class TrainerIntegrationCommon:
loader = safetensors.torch.load_file
weights_file = os.path.join(folder, SAFE_WEIGHTS_NAME)
else:
check_torch_load_is_safe()
loader = torch.load
weights_file = os.path.join(folder, WEIGHTS_NAME)