* deprecate the prev fix * reword warning and update docs * reword warning * tests * dont bloat `get_text_config()`
This commit is contained in:
committed by
GitHub
parent
2b8a15cc3f
commit
47e5432805
@@ -204,7 +204,7 @@ visualizer("<img>What is shown in this image?")
|
||||
+ do_pan_and_scan=True,
|
||||
).to("cuda")
|
||||
```
|
||||
- For text-only inputs, use [`AutoModelForCausalLM`] instead to skip loading the vision components and save resources.
|
||||
- For Gemma-3 1B checkpoint trained in text-only mode, use [`AutoModelForCausalLM`] instead.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
@@ -1118,7 +1118,9 @@ class PretrainedConfig(PushToHubMixin):
|
||||
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
||||
itself. On specific composite models, it is under a set of valid names.
|
||||
|
||||
If `decoder` is set to `True`, then only search for decoder config names.
|
||||
Args:
|
||||
decoder (`Optional[bool]`, *optional*, defaults to `False`):
|
||||
If set to `True`, then only search for decoder config names.
|
||||
"""
|
||||
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
|
||||
encoder_possible_text_config_names = ("text_encoder",)
|
||||
@@ -1140,8 +1142,10 @@ class PretrainedConfig(PushToHubMixin):
|
||||
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
|
||||
)
|
||||
elif len(valid_text_config_names) == 1:
|
||||
return getattr(self, valid_text_config_names[0])
|
||||
return self
|
||||
config_to_return = getattr(self, valid_text_config_names[0])
|
||||
else:
|
||||
config_to_return = self
|
||||
return config_to_return
|
||||
|
||||
|
||||
def get_configuration_file(configuration_files: list[str]) -> str:
|
||||
|
||||
@@ -522,7 +522,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("fuyu", "FuyuForCausalLM"),
|
||||
("gemma", "GemmaForCausalLM"),
|
||||
("gemma2", "Gemma2ForCausalLM"),
|
||||
("gemma3", "Gemma3ForCausalLM"),
|
||||
("gemma3", "Gemma3ForConditionalGeneration"),
|
||||
("gemma3_text", "Gemma3ForCausalLM"),
|
||||
("git", "GitForCausalLM"),
|
||||
("glm", "GlmForCausalLM"),
|
||||
@@ -1671,7 +1671,20 @@ class AutoModelForCausalLM(_BaseAutoModelClass):
|
||||
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
|
||||
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
|
||||
"""
|
||||
return config.get_text_config(decoder=True)
|
||||
possible_text_config_names = ("decoder", "generator", "text_config")
|
||||
text_config_names = []
|
||||
for text_config_name in possible_text_config_names:
|
||||
if hasattr(config, text_config_name):
|
||||
text_config_names += [text_config_name]
|
||||
|
||||
text_config = config.get_text_config(decoder=True)
|
||||
if text_config_names and type(text_config) in cls._model_mapping.keys():
|
||||
warnings.warn(
|
||||
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
|
||||
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
|
||||
FutureWarning,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
||||
@@ -344,7 +344,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
|
||||
def test_automodelforcausallm(self):
|
||||
"""
|
||||
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
||||
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
||||
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
|
||||
"""
|
||||
config = self.model_tester.get_config()
|
||||
@@ -352,7 +352,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(for_causal_lm, Gemma3ForCausalLM)
|
||||
self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -20,7 +20,7 @@ import unittest
|
||||
import requests
|
||||
|
||||
from transformers import PixtralProcessor
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.testing_utils import require_read_token, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
@@ -35,6 +35,7 @@ if is_vision_available():
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_read_token
|
||||
class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
"""This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user