[gemma 3] multimodal checkpoints + AutoModelForCausalLM (#36741)
This commit is contained in:
@@ -444,6 +444,11 @@ class _BaseAutoModelClass:
|
||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||
)
|
||||
|
||||
@classmethod
def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
    """
    Hook for autoclass-specific manipulation of a config after it has been loaded.

    This default implementation hands `config` back untouched. It may be overridden in subclasses.
    """
    return config
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
config = kwargs.pop("config", None)
|
||||
@@ -539,6 +544,10 @@ class _BaseAutoModelClass:
|
||||
if kwargs_orig.get("quantization_config", None) is not None:
|
||||
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
|
||||
|
||||
# AutoClass-specific config manipulation
|
||||
config = copy.deepcopy(config)
|
||||
config = cls._prepare_config_for_auto_class(config)
|
||||
|
||||
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
|
||||
has_local_code = type(config) in cls._model_mapping.keys()
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
from .auto_factory import (
|
||||
_BaseAutoBackboneClass,
|
||||
@@ -1658,6 +1659,17 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
|
||||
class AutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING

    @classmethod
    def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
        """
        Autoclass-specific config post-loading step: when the config carries a nested text decoder section,
        that section is used in place of the full config.

        The multimodal models mapped by AutoModelForCausalLM expect the text decoder to be handed its own
        config rather than the config of the whole model. This makes it possible, for instance, to load only
        the text part of a VLM.
        """
        text_decoder_config = config.get_text_config(decoder=True)
        return text_decoder_config
|
||||
|
||||
|
||||
# Rebind the class name to whatever `auto_class_update` returns for it, using "causal language modeling"
# as the head description. NOTE(review): presumably this fills in the auto-class docstrings -- confirm
# against the definition of `auto_class_update`.
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Gemma3 model."""
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -339,6 +340,18 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
def test_automodelforcausallm(self):
    """
    Regression test for #36741 -- a checkpoint saved from `Gemma3ForConditionalGeneration` must be loadable
    through `AutoModelForCausalLM.from_pretrained`, which has to pull the text config before loading the
    model, and the loaded object must be a `Gemma3ForCausalLM`.
    """
    multimodal_config = self.model_tester.get_config()
    multimodal_model = Gemma3ForConditionalGeneration(multimodal_config)
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        # Round-trip through disk so the autoclass loads the saved (multimodal) config itself.
        multimodal_model.save_pretrained(checkpoint_dir)
        text_only_model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
        self.assertIsInstance(text_only_model, Gemma3ForCausalLM)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user