Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)

* fix all main files

* fix test files

* oups forgot modular

* add link

* update message
This commit is contained in:
Cyril Vallez
2025-04-25 16:57:09 +02:00
committed by GitHub
parent eefc86aa31
commit 0cfbf9c95b
24 changed files with 88 additions and 9 deletions

View File

@@ -23,6 +23,7 @@ from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -302,6 +303,7 @@ class PatchTSTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def prepare_batch(repo_id="hf-internal-testing/etth1-hourly-batch", file="train-batch.pt"):
file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch