Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)
* fix all main files * fix test files * oups forgot modular * add link * update message
This commit is contained in:
@@ -33,6 +33,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import check_torch_load_is_safe
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -366,6 +367,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
filename="llava_1_6_input_ids.pt",
|
||||
repo_type="dataset",
|
||||
)
|
||||
check_torch_load_is_safe()
|
||||
original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True)
|
||||
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
|
||||
# remove image token indices because HF impl expands image tokens `image_seq_length` times
|
||||
@@ -378,6 +380,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
filename="llava_1_6_pixel_values.pt",
|
||||
repo_type="dataset",
|
||||
)
|
||||
check_torch_load_is_safe()
|
||||
original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)
|
||||
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user