[shieldgemma] fix checkpoint loading (#39348)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a1ad9197c5
commit
66cd995618
@@ -22,6 +22,7 @@ from PIL import Image
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_read_token,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -29,14 +30,12 @@ from transformers.testing_utils import (
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Processor
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
# @require_read_token
|
||||
@require_read_token
|
||||
class ShieldGemma2IntegrationTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
@@ -49,11 +48,9 @@ class ShieldGemma2IntegrationTest(unittest.TestCase):
|
||||
response = requests.get(url)
|
||||
image = Image.open(BytesIO(response.content))
|
||||
|
||||
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, load_in_4bit=True)
|
||||
|
||||
inputs = processor(images=[image]).to(torch_device)
|
||||
inputs = processor(images=[image], return_tensors="pt").to(torch_device)
|
||||
output = model(**inputs)
|
||||
self.assertEqual(len(output.probabilities), 3)
|
||||
for element in output.probabilities:
|
||||
|
||||
Reference in New Issue
Block a user