[gemma3] support sequence classification task (#39465)

* add seq clf class

* fix docs and add in auto-map

* skip tests

* optional pixels
This commit is contained in:
Raushan Turganbay
2025-07-21 11:03:20 +02:00
committed by GitHub
parent 34133d0a79
commit e42681b48b
5 changed files with 196 additions and 2 deletions

View File

@@ -53,6 +53,7 @@ if is_torch_available():
from transformers import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma3ForSequenceClassification,
Gemma3Model,
Gemma3Processor,
Gemma3TextModel,
@@ -246,6 +247,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
(
Gemma3Model,
Gemma3ForConditionalGeneration,
Gemma3ForSequenceClassification,
)
if is_torch_available()
else ()
@@ -348,6 +350,14 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_initialization(self):
pass
@unittest.skip("Loading nested configs with overwritten `kwargs` isn't supported yet, FIXME @raushan.")
def test_load_with_mismatched_shapes(self):
pass
@unittest.skip("Loading nested configs with overwritten `kwargs` isn't supported yet, FIXME @raushan.")
def test_mismatched_shapes_have_properly_initialized_weights(self):
pass
def test_automodelforcausallm(self):
"""
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that