[gemma 3] multimodal checkpoints + AutoModelForCausalLM (#36741)
This commit is contained in:
@@ -444,6 +444,11 @@ class _BaseAutoModelClass:
|
|||||||
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
    """
    Hook for additional autoclass-specific manipulation of a config after it has been loaded.

    This base implementation performs no manipulation and hands `config` back unchanged. It may be
    overridden in subclasses.

    Args:
        config: The loaded model configuration.

    Returns:
        The same `config` object that was passed in.
    """
    # Nothing to adjust at this level: the loaded config is returned as-is.
    return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
config = kwargs.pop("config", None)
|
config = kwargs.pop("config", None)
|
||||||
@@ -539,6 +544,10 @@ class _BaseAutoModelClass:
|
|||||||
if kwargs_orig.get("quantization_config", None) is not None:
|
if kwargs_orig.get("quantization_config", None) is not None:
|
||||||
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
|
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
|
||||||
|
|
||||||
|
# AutoClass-specific config manipulation
|
||||||
|
config = copy.deepcopy(config)
|
||||||
|
config = cls._prepare_config_for_auto_class(config)
|
||||||
|
|
||||||
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
|
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
|
||||||
has_local_code = type(config) in cls._model_mapping.keys()
|
has_local_code = type(config) in cls._model_mapping.keys()
|
||||||
trust_remote_code = resolve_trust_remote_code(
|
trust_remote_code = resolve_trust_remote_code(
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .auto_factory import (
|
from .auto_factory import (
|
||||||
_BaseAutoBackboneClass,
|
_BaseAutoBackboneClass,
|
||||||
@@ -1658,6 +1659,17 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
|
|||||||
class AutoModelForCausalLM(_BaseAutoModelClass):
|
class AutoModelForCausalLM(_BaseAutoModelClass):
|
||||||
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
|
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
|
||||||
|
|
||||||
|
@classmethod
def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
    """
    Additional autoclass-specific config post-loading manipulation. For this autoclass, a config that has a
    nested text decoder section is replaced by that section.

    Multimodal models mapped by AutoModelForCausalLM assume, under the hood, that the text decoder receives
    its own config rather than the config for the whole model. This is used e.g. to load the text-only part
    of a VLM.

    Args:
        config: The loaded model configuration.

    Returns:
        The result of `config.get_text_config(decoder=True)`.
    """
    # Delegate the selection of the decoder section to the config itself.
    text_decoder_config = config.get_text_config(decoder=True)
    return text_decoder_config
|
||||||
|
|
||||||
|
|
||||||
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Testing suite for the PyTorch Gemma3 model."""
|
"""Testing suite for the PyTorch Gemma3 model."""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -339,6 +340,18 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
|||||||
def test_flex_attention_with_grads(self):
|
def test_flex_attention_with_grads(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_automodelforcausallm(self):
    """
    Regression test for #36741 -- checks that `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
    `AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model.
    """
    # Build the full multimodal model from the tester's config and save it as a checkpoint.
    multimodal_model = Gemma3ForConditionalGeneration(self.model_tester.get_config())
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        multimodal_model.save_pretrained(checkpoint_dir)
        # Reloading through the causal-LM autoclass must yield the text-only model class.
        loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
        self.assertIsInstance(loaded_model, Gemma3ForCausalLM)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
|
|||||||
Reference in New Issue
Block a user