* deprecate the prev fix * reword warning and update docs * reword warning * tests * dont bloat `get_text_config()`
This commit is contained in:
committed by
GitHub
parent
2b8a15cc3f
commit
47e5432805
@@ -204,7 +204,7 @@ visualizer("<img>What is shown in this image?")
|
|||||||
+ do_pan_and_scan=True,
|
+ do_pan_and_scan=True,
|
||||||
).to("cuda")
|
).to("cuda")
|
||||||
```
|
```
|
||||||
- For text-only inputs, use [`AutoModelForCausalLM`] instead to skip loading the vision components and save resources.
|
- For Gemma-3 1B checkpoint trained in text-only mode, use [`AutoModelForCausalLM`] instead.
|
||||||
|
|
||||||
```py
|
```py
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -1118,7 +1118,9 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
||||||
itself. On specific composite models, it is under a set of valid names.
|
itself. On specific composite models, it is under a set of valid names.
|
||||||
|
|
||||||
If `decoder` is set to `True`, then only search for decoder config names.
|
Args:
|
||||||
|
decoder (`Optional[bool]`, *optional*, defaults to `False`):
|
||||||
|
If set to `True`, then only search for decoder config names.
|
||||||
"""
|
"""
|
||||||
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
|
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
|
||||||
encoder_possible_text_config_names = ("text_encoder",)
|
encoder_possible_text_config_names = ("text_encoder",)
|
||||||
@@ -1140,8 +1142,10 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
|
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
|
||||||
)
|
)
|
||||||
elif len(valid_text_config_names) == 1:
|
elif len(valid_text_config_names) == 1:
|
||||||
return getattr(self, valid_text_config_names[0])
|
config_to_return = getattr(self, valid_text_config_names[0])
|
||||||
return self
|
else:
|
||||||
|
config_to_return = self
|
||||||
|
return config_to_return
|
||||||
|
|
||||||
|
|
||||||
def get_configuration_file(configuration_files: list[str]) -> str:
|
def get_configuration_file(configuration_files: list[str]) -> str:
|
||||||
|
|||||||
@@ -522,7 +522,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("fuyu", "FuyuForCausalLM"),
|
("fuyu", "FuyuForCausalLM"),
|
||||||
("gemma", "GemmaForCausalLM"),
|
("gemma", "GemmaForCausalLM"),
|
||||||
("gemma2", "Gemma2ForCausalLM"),
|
("gemma2", "Gemma2ForCausalLM"),
|
||||||
("gemma3", "Gemma3ForCausalLM"),
|
("gemma3", "Gemma3ForConditionalGeneration"),
|
||||||
("gemma3_text", "Gemma3ForCausalLM"),
|
("gemma3_text", "Gemma3ForCausalLM"),
|
||||||
("git", "GitForCausalLM"),
|
("git", "GitForCausalLM"),
|
||||||
("glm", "GlmForCausalLM"),
|
("glm", "GlmForCausalLM"),
|
||||||
@@ -1671,7 +1671,20 @@ class AutoModelForCausalLM(_BaseAutoModelClass):
|
|||||||
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
|
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
|
||||||
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
|
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
|
||||||
"""
|
"""
|
||||||
return config.get_text_config(decoder=True)
|
possible_text_config_names = ("decoder", "generator", "text_config")
|
||||||
|
text_config_names = []
|
||||||
|
for text_config_name in possible_text_config_names:
|
||||||
|
if hasattr(config, text_config_name):
|
||||||
|
text_config_names += [text_config_name]
|
||||||
|
|
||||||
|
text_config = config.get_text_config(decoder=True)
|
||||||
|
if text_config_names and type(text_config) in cls._model_mapping.keys():
|
||||||
|
warnings.warn(
|
||||||
|
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
|
||||||
|
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
||||||
|
|||||||
@@ -344,7 +344,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
|||||||
|
|
||||||
def test_automodelforcausallm(self):
|
def test_automodelforcausallm(self):
|
||||||
"""
|
"""
|
||||||
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
||||||
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
|
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
|
||||||
"""
|
"""
|
||||||
config = self.model_tester.get_config()
|
config = self.model_tester.get_config()
|
||||||
@@ -352,7 +352,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
|||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
model.save_pretrained(tmp_dir)
|
model.save_pretrained(tmp_dir)
|
||||||
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||||
self.assertIsInstance(for_causal_lm, Gemma3ForCausalLM)
|
self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import unittest
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from transformers import PixtralProcessor
|
from transformers import PixtralProcessor
|
||||||
from transformers.testing_utils import require_vision
|
from transformers.testing_utils import require_read_token, require_vision
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
from ...test_processing_common import ProcessorTesterMixin
|
from ...test_processing_common import ProcessorTesterMixin
|
||||||
@@ -35,6 +35,7 @@ if is_vision_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
|
@require_read_token
|
||||||
class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||||
"""This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""
|
"""This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user