From 7c233980f4a8b8d4a8426e2072cb9dcdfaefc4f7 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 19 Mar 2025 15:04:19 +0000
Subject: [PATCH] [gemma 3] multimodal checkpoints + AutoModelForCausalLM (#36741)

---
 src/transformers/models/auto/auto_factory.py  |  9 +++++++++
 src/transformers/models/auto/modeling_auto.py | 12 ++++++++++++
 tests/models/gemma3/test_modeling_gemma3.py   | 13 +++++++++++++
 3 files changed, 34 insertions(+)

diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py
index 069269bb7f..c06017fd76 100644
--- a/src/transformers/models/auto/auto_factory.py
+++ b/src/transformers/models/auto/auto_factory.py
@@ -444,6 +444,11 @@ class _BaseAutoModelClass:
             f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
         )
 
+    @classmethod
+    def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
+        """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses."""
+        return config
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         config = kwargs.pop("config", None)
@@ -539,6 +544,10 @@ class _BaseAutoModelClass:
             if kwargs_orig.get("quantization_config", None) is not None:
                 kwargs["quantization_config"] = kwargs_orig["quantization_config"]
 
+        # AutoClass-specific config manipulation
+        config = copy.deepcopy(config)
+        config = cls._prepare_config_for_auto_class(config)
+
         has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
         has_local_code = type(config) in cls._model_mapping.keys()
         trust_remote_code = resolve_trust_remote_code(
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 90dee0eb58..f08fc2fa04 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -17,6 +17,7 @@
 import warnings
 from collections import OrderedDict
 
+from ...configuration_utils import PretrainedConfig
 from ...utils import logging
 from .auto_factory import (
     _BaseAutoBackboneClass,
@@ -1658,6 +1659,17 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
 class AutoModelForCausalLM(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
 
+    @classmethod
+    def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
+        """
+        Additional autoclass-specific config post-loading manipulation. In this specific autoclass, if the config has
+        a nested text decoder section, uses that section instead.
+
+        Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
+        config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
+        """
+        return config.get_text_config(decoder=True)
+
 
 AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
 
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py
index 3586d35cb3..a5a44330d8 100644
--- a/tests/models/gemma3/test_modeling_gemma3.py
+++ b/tests/models/gemma3/test_modeling_gemma3.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Testing suite for the PyTorch Gemma3 model."""
 
+import tempfile
 import unittest
 
 import pytest
@@ -339,6 +340,18 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
     def test_flex_attention_with_grads(self):
         pass
 
+    def test_automodelforcausallm(self):
+        """
+        Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
+        `AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
+        """
+        config = self.model_tester.get_config()
+        model = Gemma3ForConditionalGeneration(config)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
+            self.assertIsInstance(for_causal_lm, Gemma3ForCausalLM)
+
 
 @slow
 @require_torch_gpu