[shieldgemma] fix checkpoint loading (#39348)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a1ad9197c5
commit
66cd995618
@@ -240,6 +240,7 @@ VLMS = [
|
|||||||
"mistral3",
|
"mistral3",
|
||||||
"mllama",
|
"mllama",
|
||||||
"paligemma",
|
"paligemma",
|
||||||
|
"shieldgemma2",
|
||||||
"qwen2vl",
|
"qwen2vl",
|
||||||
"qwen2_5_vl",
|
"qwen2_5_vl",
|
||||||
"videollava",
|
"videollava",
|
||||||
|
|||||||
@@ -45,6 +45,12 @@ class ShieldGemma2ImageClassifierOutputWithNoAttention(ImageClassifierOutputWith
|
|||||||
@auto_docstring
|
@auto_docstring
|
||||||
class ShieldGemma2ForImageClassification(PreTrainedModel):
|
class ShieldGemma2ForImageClassification(PreTrainedModel):
|
||||||
config_class = ShieldGemma2Config
|
config_class = ShieldGemma2Config
|
||||||
|
_checkpoint_conversion_mapping = {
|
||||||
|
"model.language_model.model": "model.model.language_model",
|
||||||
|
"model.vision_tower": "model.model.vision_tower",
|
||||||
|
"model.multi_modal_projector": "model.model.multi_modal_projector",
|
||||||
|
"model.language_model.lm_head": "model.lm_head",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, config: ShieldGemma2Config):
|
def __init__(self, config: ShieldGemma2Config):
|
||||||
super().__init__(config=config)
|
super().__init__(config=config)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from PIL import Image
|
|||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
cleanup,
|
cleanup,
|
||||||
|
require_read_token,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
@@ -29,14 +30,12 @@ from transformers.testing_utils import (
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
|
||||||
|
|
||||||
from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Processor
|
from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Processor
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
# @require_read_token
|
@require_read_token
|
||||||
class ShieldGemma2IntegrationTest(unittest.TestCase):
|
class ShieldGemma2IntegrationTest(unittest.TestCase):
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
cleanup(torch_device, gc_collect=True)
|
cleanup(torch_device, gc_collect=True)
|
||||||
@@ -49,11 +48,9 @@ class ShieldGemma2IntegrationTest(unittest.TestCase):
|
|||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
image = Image.open(BytesIO(response.content))
|
image = Image.open(BytesIO(response.content))
|
||||||
|
|
||||||
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(
|
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, load_in_4bit=True)
|
||||||
torch_device
|
|
||||||
)
|
|
||||||
|
|
||||||
inputs = processor(images=[image]).to(torch_device)
|
inputs = processor(images=[image], return_tensors="pt").to(torch_device)
|
||||||
output = model(**inputs)
|
output = model(**inputs)
|
||||||
self.assertEqual(len(output.probabilities), 3)
|
self.assertEqual(len(output.probabilities), 3)
|
||||||
for element in output.probabilities:
|
for element in output.probabilities:
|
||||||
|
|||||||
Reference in New Issue
Block a user