diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a2913f2296..7153257a10 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -240,6 +240,7 @@ VLMS = [ "mistral3", "mllama", "paligemma", + "shieldgemma2", "qwen2vl", "qwen2_5_vl", "videollava", diff --git a/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py index a164e61420..a77ea28a22 100644 --- a/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py +++ b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py @@ -45,6 +45,12 @@ class ShieldGemma2ImageClassifierOutputWithNoAttention(ImageClassifierOutputWith @auto_docstring class ShieldGemma2ForImageClassification(PreTrainedModel): config_class = ShieldGemma2Config + _checkpoint_conversion_mapping = { + "model.language_model.model": "model.model.language_model", + "model.vision_tower": "model.model.vision_tower", + "model.multi_modal_projector": "model.model.multi_modal_projector", + "model.language_model.lm_head": "model.lm_head", + } def __init__(self, config: ShieldGemma2Config): super().__init__(config=config) diff --git a/tests/models/shieldgemma2/test_modeling_shieldgemma2.py b/tests/models/shieldgemma2/test_modeling_shieldgemma2.py index de41ad0fe0..4683c76bc6 100644 --- a/tests/models/shieldgemma2/test_modeling_shieldgemma2.py +++ b/tests/models/shieldgemma2/test_modeling_shieldgemma2.py @@ -22,6 +22,7 @@ from PIL import Image from transformers import is_torch_available from transformers.testing_utils import ( cleanup, + require_read_token, require_torch_accelerator, slow, torch_device, @@ -29,14 +30,12 @@ from transformers.testing_utils import ( if is_torch_available(): - import torch - from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Processor @slow @require_torch_accelerator -# @require_read_token +@require_read_token class ShieldGemma2IntegrationTest(unittest.TestCase): def tearDown(self): cleanup(torch_device, gc_collect=True) @@ -49,11 +48,9 @@ class ShieldGemma2IntegrationTest(unittest.TestCase): response = requests.get(url) image = Image.open(BytesIO(response.content)) - model = ShieldGemma2ForImageClassification.from_pretrained(model_id, torch_dtype=torch.bfloat16).to( - torch_device - ) + model = ShieldGemma2ForImageClassification.from_pretrained(model_id, load_in_4bit=True) - inputs = processor(images=[image]).to(torch_device) + inputs = processor(images=[image], return_tensors="pt").to(torch_device) output = model(**inputs) self.assertEqual(len(output.probabilities), 3) for element in output.probabilities: