Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)
* fix all main files * fix test files * oups forgot modular * add link * update message
This commit is contained in:
@@ -27,6 +27,7 @@ from parameterized import parameterized
|
||||
from transformers import is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
|
||||
from transformers.utils import check_torch_load_is_safe
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
@@ -451,6 +452,7 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
def prepare_batch(repo_id="ibm/patchtsmixer-etth1-test-data", file="pretrain_batch.pt"):
|
||||
# TODO: Make repo public
|
||||
file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
|
||||
check_torch_load_is_safe()
|
||||
batch = torch.load(file, map_location=torch_device, weights_only=True)
|
||||
return batch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user