Deprecate #36741 and map Causal to Conditional (#36917)

* deprecate the prev fix

* reword warning and update docs

* reword warning

* tests

* dont bloat `get_text_config()`
This commit is contained in:
Raushan Turganbay
2025-03-25 09:13:56 +01:00
committed by Arthur Zucker
parent e6ab93e702
commit d9ccb9adbb
4 changed files with 83 additions and 50 deletions

View File

@@ -105,19 +105,28 @@ inputs = processor.apply_chat_template(
add_generation_prompt=True, add_generation_prompt=True,
).to(model.device) ).to(model.device)
output = model.generate(**inputs, max_new_tokens=50) output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ]) print(processor.decode(output[0], skip_special_tokens=True))
``` ```
### Multi-image Inference Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
```python ```py
model_id = "google/gemma-3-4b-it" from transformers.utils.attention_visualizer import AttentionMaskVisualizer
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
visualizer = AttentionMaskVisualizer("google/gemma-3-4b-it")
visualizer("<img>What is shown in this image?")
```
## Notes
- Use [`Gemma3ForConditionalGeneration`] for image-and-text and image-only inputs.
- Gemma 3 supports multiple input images, but make sure the images are correctly batched before passing them to the processor. Each batch should be a list of one or more images.
```py
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4=" url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg" url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
messages =[ messages =[
{ {
"role": "system", "role": "system",
@@ -126,38 +135,45 @@ messages = [
] ]
}, },
{ {
"role": "user", "content": [ "role": "user",
"content": [
{"type": "image", "url": url_cow}, {"type": "image", "url": url_cow},
{"type": "image", "url": url_stop}, {"type": "image", "url": url_cat},
{"type": "text", "text": "Are these two images identical?"}, {"type": "text", "text": "Which image is cuter?"},
] ]
}, },
] ]
```
- Text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
- By default, images aren't cropped and only the base image is forwarded to the model. With high-resolution images or images with non-square aspect ratios, this can produce artifacts because the vision encoder uses a fixed resolution of 896x896. To prevent these artifacts and improve performance during inference, set `do_pan_and_scan=True` to crop the image into multiple smaller patches and concatenate them with the base image embedding. You can disable pan and scan for faster inference.
```diff
inputs = processor.apply_chat_template( inputs = processor.apply_chat_template(
messages, messages,
tokenize=True, tokenize=True,
return_dict=True, return_dict=True,
return_tensors="pt", return_tensors="pt",
add_generation_prompt=True, add_generation_prompt=True,
).to(model.device) + do_pan_and_scan=True,
).to("cuda")
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
``` ```
- For the Gemma 3 1B checkpoint, which was trained in text-only mode, use [`AutoModelForCausalLM`] instead.
### Text-only inference ```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
You can use the VLMs for text-only generation by omitting images in your input. However, you can also load the models in text-only mode as shown below. This will skip loading the vision tower and will save resources when you just need the LLM capabilities. tokenizer = AutoTokenizer.from_pretrained(
```python "google/gemma-3-1b-pt",
from transformers import AutoTokenizer, Gemma3ForCausalLM )
model = AutoModelForCausalLM.from_pretrained(
model_id = "google/gemma-3-1b-it" "google/gemma-3-1b-pt",
torch_dtype=torch.bfloat16,
tokenizer = AutoTokenizer.from_pretrained(model_id) device_map="auto",
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto") attn_implementation="sdpa"
)
input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device) input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=100) outputs = model.generate(**input_ids, max_new_tokens=100)
text = tokenizer.batch_decode(outputs, skip_special_tokens=True) text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

View File

@@ -1122,7 +1122,9 @@ class PretrainedConfig(PushToHubMixin):
Returns the config that is meant to be used with text IO. On most models, it is the original config instance Returns the config that is meant to be used with text IO. On most models, it is the original config instance
itself. On specific composite models, it is under a set of valid names. itself. On specific composite models, it is under a set of valid names.
If `decoder` is set to `True`, then only search for decoder config names. Args:
decoder (`Optional[bool]`, *optional*, defaults to `False`):
If set to `True`, then only search for decoder config names.
""" """
decoder_possible_text_config_names = ("decoder", "generator", "text_config") decoder_possible_text_config_names = ("decoder", "generator", "text_config")
encoder_possible_text_config_names = ("text_encoder",) encoder_possible_text_config_names = ("text_encoder",)
@@ -1144,8 +1146,10 @@ class PretrainedConfig(PushToHubMixin):
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly." "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
) )
elif len(valid_text_config_names) == 1: elif len(valid_text_config_names) == 1:
return getattr(self, valid_text_config_names[0]) config_to_return = getattr(self, valid_text_config_names[0])
return self else:
config_to_return = self
return config_to_return
def get_configuration_file(configuration_files: List[str]) -> str: def get_configuration_file(configuration_files: List[str]) -> str:

View File

@@ -522,7 +522,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("fuyu", "FuyuForCausalLM"), ("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"), ("gemma", "GemmaForCausalLM"),
("gemma2", "Gemma2ForCausalLM"), ("gemma2", "Gemma2ForCausalLM"),
("gemma3", "Gemma3ForCausalLM"), ("gemma3", "Gemma3ForConditionalGeneration"),
("gemma3_text", "Gemma3ForCausalLM"), ("gemma3_text", "Gemma3ForCausalLM"),
("git", "GitForCausalLM"), ("git", "GitForCausalLM"),
("glm", "GlmForCausalLM"), ("glm", "GlmForCausalLM"),
@@ -1671,7 +1671,20 @@ class AutoModelForCausalLM(_BaseAutoModelClass):
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM. config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
""" """
return config.get_text_config(decoder=True) possible_text_config_names = ("decoder", "generator", "text_config")
text_config_names = []
for text_config_name in possible_text_config_names:
if hasattr(config, text_config_name):
text_config_names += [text_config_name]
text_config = config.get_text_config(decoder=True)
if text_config_names and type(text_config) in cls._model_mapping.keys():
warnings.warn(
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
FutureWarning,
)
return config
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")

View File

@@ -344,7 +344,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
def test_automodelforcausallm(self): def test_automodelforcausallm(self):
""" """
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model `AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
""" """
config = self.model_tester.get_config() config = self.model_tester.get_config()
@@ -352,7 +352,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir) model.save_pretrained(tmp_dir)
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir) for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
self.assertIsInstance(for_causal_lm, Gemma3ForCausalLM) self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)
@slow @slow