Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)
* fix all main files * fix test files * oups forgot modular * add link * update message
This commit is contained in:
@@ -113,6 +113,7 @@ from transformers.utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
check_torch_load_is_safe,
|
||||
is_accelerate_available,
|
||||
is_apex_available,
|
||||
is_bitsandbytes_available,
|
||||
@@ -646,6 +647,7 @@ class TrainerIntegrationCommon:
|
||||
else:
|
||||
best_model = RegressionModel()
|
||||
if not safe_weights:
|
||||
check_torch_load_is_safe()
|
||||
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME), weights_only=True)
|
||||
else:
|
||||
state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME))
|
||||
@@ -678,6 +680,7 @@ class TrainerIntegrationCommon:
|
||||
loader = safetensors.torch.load_file
|
||||
weights_file = os.path.join(folder, SAFE_WEIGHTS_NAME)
|
||||
else:
|
||||
check_torch_load_is_safe()
|
||||
loader = torch.load
|
||||
weights_file = os.path.join(folder, WEIGHTS_NAME)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user