[shieldgemma] fix checkpoint loading (#39348)

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Raushan Turganbay
2025-07-14 11:34:58 +05:00
committed by GitHub
parent a1ad9197c5
commit 66cd995618
3 changed files with 11 additions and 7 deletions

View File

@@ -22,6 +22,7 @@ from PIL import Image
from transformers import is_torch_available
from transformers.testing_utils import (
cleanup,
require_read_token,
require_torch_accelerator,
slow,
torch_device,
@@ -29,14 +30,12 @@ from transformers.testing_utils import (
if is_torch_available():
import torch
from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Processor
@slow
@require_torch_accelerator
# @require_read_token
@require_read_token
class ShieldGemma2IntegrationTest(unittest.TestCase):
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@@ -49,11 +48,9 @@ class ShieldGemma2IntegrationTest(unittest.TestCase):
response = requests.get(url)
image = Image.open(BytesIO(response.content))
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(
torch_device
)
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, load_in_4bit=True)
inputs = processor(images=[image]).to(torch_device)
inputs = processor(images=[image], return_tensors="pt").to(torch_device)
output = model(**inputs)
self.assertEqual(len(output.probabilities), 3)
for element in output.probabilities: