Deprecate #36741 and map Causal to Conditional (#36917)

* deprecate the prev fix

* reword warning and update docs

* reword warning

* tests

* don't bloat `get_text_config()`
This commit is contained in:
Raushan Turganbay
2025-03-25 09:13:56 +01:00
committed by GitHub
parent 2b8a15cc3f
commit 47e5432805
5 changed files with 27 additions and 9 deletions

View File

@@ -204,7 +204,7 @@ visualizer("<img>What is shown in this image?")
+ do_pan_and_scan=True,
).to("cuda")
```
- For text-only inputs, use [`AutoModelForCausalLM`] instead to skip loading the vision components and save resources.
- For the Gemma-3 1B checkpoint, which was trained in text-only mode, use [`AutoModelForCausalLM`] instead.
```py
import torch

View File

@@ -1118,7 +1118,9 @@ class PretrainedConfig(PushToHubMixin):
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names.
If `decoder` is set to `True`, then only search for decoder config names.
Args:
decoder (`Optional[bool]`, *optional*, defaults to `False`):
If set to `True`, then only search for decoder config names.
"""
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
encoder_possible_text_config_names = ("text_encoder",)
@@ -1140,8 +1142,10 @@ class PretrainedConfig(PushToHubMixin):
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
)
elif len(valid_text_config_names) == 1:
return getattr(self, valid_text_config_names[0])
return self
config_to_return = getattr(self, valid_text_config_names[0])
else:
config_to_return = self
return config_to_return
def get_configuration_file(configuration_files: list[str]) -> str:

View File

@@ -522,7 +522,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"),
("gemma2", "Gemma2ForCausalLM"),
("gemma3", "Gemma3ForCausalLM"),
("gemma3", "Gemma3ForConditionalGeneration"),
("gemma3_text", "Gemma3ForCausalLM"),
("git", "GitForCausalLM"),
("glm", "GlmForCausalLM"),
@@ -1671,7 +1671,20 @@ class AutoModelForCausalLM(_BaseAutoModelClass):
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
"""
return config.get_text_config(decoder=True)
possible_text_config_names = ("decoder", "generator", "text_config")
text_config_names = []
for text_config_name in possible_text_config_names:
if hasattr(config, text_config_name):
text_config_names += [text_config_name]
text_config = config.get_text_config(decoder=True)
if text_config_names and type(text_config) in cls._model_mapping.keys():
warnings.warn(
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
FutureWarning,
)
return config
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")

View File

@@ -344,7 +344,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_automodelforcausallm(self):
"""
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
"""
config = self.model_tester.get_config()
@@ -352,7 +352,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
self.assertIsInstance(for_causal_lm, Gemma3ForCausalLM)
self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)
@slow

View File

@@ -20,7 +20,7 @@ import unittest
import requests
from transformers import PixtralProcessor
from transformers.testing_utils import require_vision
from transformers.testing_utils import require_read_token, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
@@ -35,6 +35,7 @@ if is_vision_available():
@require_vision
@require_read_token
class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
"""This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""