[gemma 3] multimodal checkpoints + AutoModelForCausalLM (#36741)

This commit is contained in:
Joao Gante
2025-03-19 15:04:19 +00:00
committed by GitHub
parent b11050d6a2
commit 7c233980f4
3 changed files with 34 additions and 0 deletions

View File

@@ -444,6 +444,11 @@ class _BaseAutoModelClass:
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
)
@classmethod
def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
    """
    Hook for auto-class-specific adjustments applied to a config after it has been loaded.

    The base implementation is a no-op: it hands `config` back untouched. Subclasses may override it to
    substitute or modify the config before the model class is resolved.

    Args:
        config (`PretrainedConfig`): The loaded config.

    Returns:
        `PretrainedConfig`: The config to use from this point on (here, the very same object).
    """
    return config
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None)
@@ -539,6 +544,10 @@ class _BaseAutoModelClass:
if kwargs_orig.get("quantization_config", None) is not None:
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
# AutoClass-specific config manipulation
config = copy.deepcopy(config)
config = cls._prepare_config_for_auto_class(config)
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(

View File

@@ -17,6 +17,7 @@
import warnings
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from .auto_factory import (
_BaseAutoBackboneClass,
@@ -1658,6 +1659,17 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING

    @classmethod
    def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
        """
        Additional autoclass-specific config post-loading manipulation. In this specific autoclass, if the config has
        a nested text decoder section that this autoclass knows how to load, uses that section instead.

        Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
        config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.

        The config is returned unchanged when:
            - it registers remote code for this autoclass through `auto_map`. `from_pretrained` reads `auto_map`
              from the config returned by this hook, so swapping in the nested section would hide that registration
              (NOTE(review): assumes nested text configs do not carry their own `auto_map` -- confirm).
            - the class of the text decoder config is not a key of `_model_mapping`, i.e. this autoclass would not
              be able to resolve a model class from it.

        Args:
            config (`PretrainedConfig`): The loaded (possibly composite) model config.

        Returns:
            `PretrainedConfig`: The text decoder config when the swap is safe, otherwise `config` itself.
        """
        # Same check `from_pretrained` uses to detect remote code: keep the top-level config so it still succeeds.
        if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
            return config

        text_config = config.get_text_config(decoder=True)
        # Only swap when the text decoder config maps to a model class of its own; otherwise the original config
        # is the only one `from_pretrained` has a chance of resolving.
        if type(text_config) in cls._model_mapping.keys():
            return text_config
        return config
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Gemma3 model."""
import tempfile
import unittest
import pytest
@@ -339,6 +340,18 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_flex_attention_with_grads(self):
pass
def test_automodelforcausallm(self):
    """
    Regression test for #36741.

    Saves a full `Gemma3ForConditionalGeneration` checkpoint and reloads it through
    `AutoModelForCausalLM.from_pretrained`, which is expected to pick the text config and therefore
    instantiate `Gemma3ForCausalLM`.
    """
    multimodal_model = Gemma3ForConditionalGeneration(self.model_tester.get_config())
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        multimodal_model.save_pretrained(checkpoint_dir)
        text_only_model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
        self.assertIsInstance(text_only_model, Gemma3ForCausalLM)
@slow
@require_torch_gpu