* deprecate the prev fix * reword warning and update docs * reword warning * tests * dont bloat `get_text_config()`
This commit is contained in:
committed by
Arthur Zucker
parent
e6ab93e702
commit
d9ccb9adbb
@@ -105,59 +105,75 @@ inputs = processor.apply_chat_template(
|
|||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
).to(model.device)
|
).to(model.device)
|
||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=50)
|
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
|
||||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
print(processor.decode(output[0], skip_special_tokens=True))
|
||||||
```
|
```
|
||||||
|
|
||||||
### Multi-image Inference
|
Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.
|
||||||
|
|
||||||
```python
|
```py
|
||||||
model_id = "google/gemma-3-4b-it"
|
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||||
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
|
|
||||||
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
|
|
||||||
|
|
||||||
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
|
||||||
url_stop = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "You are a helpful assistant."}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user", "content": [
|
|
||||||
{"type": "image", "url": url_cow},
|
|
||||||
{"type": "image", "url": url_stop},
|
|
||||||
{"type": "text", "text": "Are these two images identical?"},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
]
|
|
||||||
inputs = processor.apply_chat_template(
|
|
||||||
messages,
|
|
||||||
tokenize=True,
|
|
||||||
return_dict=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
add_generation_prompt=True,
|
|
||||||
).to(model.device)
|
|
||||||
|
|
||||||
output = model.generate(**inputs, max_new_tokens=50)
|
|
||||||
print(processor.decode(output[0], skip_special_tokens=True)[inputs.input_ids.shape[1]: ])
|
|
||||||
|
|
||||||
|
visualizer = AttentionMaskVisualizer("google/gemma-3-4b-it")
|
||||||
|
visualizer("<img>What is shown in this image?")
|
||||||
```
|
```
|
||||||
|
|
||||||
### Text-only inference
|
## Notes
|
||||||
|
|
||||||
You can use the VLMs for text-only generation by omitting images in your input. However, you can also load the models in text-only mode as shown below. This will skip loading the vision tower and will save resources when you just need the LLM capabilities.
|
- Use [`Gemma3ForConditionalGeneration`] for image-and-text and image-only inputs.
|
||||||
```python
|
- Gemma 3 supports multiple input images, but make sure the images are correctly batched before passing them to the processor. Each batch should be a list of one or more images.
|
||||||
from transformers import AutoTokenizer, Gemma3ForCausalLM
|
|
||||||
|
|
||||||
model_id = "google/gemma-3-1b-it"
|
```py
|
||||||
|
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
|
||||||
|
url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
messages =[
|
||||||
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="auto")
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "You are a helpful assistant."}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": url_cow},
|
||||||
|
{"type": "image", "url": url_cat},
|
||||||
|
{"type": "text", "text": "Which image is cuter?"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
```
|
||||||
|
- Text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
|
||||||
|
- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
|
||||||
|
- By default, images aren't cropped and only the base image is forwarded to the model. In high resolution images or images with non-square aspect ratios, artifacts can result because the vision encoder uses a fixed resolution of 896x896. To prevent these artifacts and improve performance during inference, set `do_pan_and_scan=True` to crop the image into multiple smaller patches and concatenate them with the base image embedding. You can disable pan and scan for faster inference.
|
||||||
|
|
||||||
input_ids = tokenizer("Write me a poem about Machine Learning.", return_tensors="pt").to(model.device)
|
```diff
|
||||||
|
inputs = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
add_generation_prompt=True,
|
||||||
|
+ do_pan_and_scan=True,
|
||||||
|
).to("cuda")
|
||||||
|
```
|
||||||
|
- For Gemma-3 1B checkpoint trained in text-only mode, use [`AutoModelForCausalLM`] instead.
|
||||||
|
|
||||||
|
```py
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
"google/gemma-3-1b-pt",
|
||||||
|
)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"google/gemma-3-1b-pt",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="auto",
|
||||||
|
attn_implementation="sdpa"
|
||||||
|
)
|
||||||
|
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
outputs = model.generate(**input_ids, max_new_tokens=100)
|
outputs = model.generate(**input_ids, max_new_tokens=100)
|
||||||
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
|||||||
@@ -1122,7 +1122,9 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
Returns the config that is meant to be used with text IO. On most models, it is the original config instance
|
||||||
itself. On specific composite models, it is under a set of valid names.
|
itself. On specific composite models, it is under a set of valid names.
|
||||||
|
|
||||||
If `decoder` is set to `True`, then only search for decoder config names.
|
Args:
|
||||||
|
decoder (`Optional[bool]`, *optional*, defaults to `False`):
|
||||||
|
If set to `True`, then only search for decoder config names.
|
||||||
"""
|
"""
|
||||||
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
|
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
|
||||||
encoder_possible_text_config_names = ("text_encoder",)
|
encoder_possible_text_config_names = ("text_encoder",)
|
||||||
@@ -1144,8 +1146,10 @@ class PretrainedConfig(PushToHubMixin):
|
|||||||
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
|
"case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
|
||||||
)
|
)
|
||||||
elif len(valid_text_config_names) == 1:
|
elif len(valid_text_config_names) == 1:
|
||||||
return getattr(self, valid_text_config_names[0])
|
config_to_return = getattr(self, valid_text_config_names[0])
|
||||||
return self
|
else:
|
||||||
|
config_to_return = self
|
||||||
|
return config_to_return
|
||||||
|
|
||||||
|
|
||||||
def get_configuration_file(configuration_files: List[str]) -> str:
|
def get_configuration_file(configuration_files: List[str]) -> str:
|
||||||
|
|||||||
@@ -522,7 +522,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("fuyu", "FuyuForCausalLM"),
|
("fuyu", "FuyuForCausalLM"),
|
||||||
("gemma", "GemmaForCausalLM"),
|
("gemma", "GemmaForCausalLM"),
|
||||||
("gemma2", "Gemma2ForCausalLM"),
|
("gemma2", "Gemma2ForCausalLM"),
|
||||||
("gemma3", "Gemma3ForCausalLM"),
|
("gemma3", "Gemma3ForConditionalGeneration"),
|
||||||
("gemma3_text", "Gemma3ForCausalLM"),
|
("gemma3_text", "Gemma3ForCausalLM"),
|
||||||
("git", "GitForCausalLM"),
|
("git", "GitForCausalLM"),
|
||||||
("glm", "GlmForCausalLM"),
|
("glm", "GlmForCausalLM"),
|
||||||
@@ -1671,7 +1671,20 @@ class AutoModelForCausalLM(_BaseAutoModelClass):
|
|||||||
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
|
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
|
||||||
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
|
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
|
||||||
"""
|
"""
|
||||||
return config.get_text_config(decoder=True)
|
possible_text_config_names = ("decoder", "generator", "text_config")
|
||||||
|
text_config_names = []
|
||||||
|
for text_config_name in possible_text_config_names:
|
||||||
|
if hasattr(config, text_config_name):
|
||||||
|
text_config_names += [text_config_name]
|
||||||
|
|
||||||
|
text_config = config.get_text_config(decoder=True)
|
||||||
|
if text_config_names and type(text_config) in cls._model_mapping.keys():
|
||||||
|
warnings.warn(
|
||||||
|
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
|
||||||
|
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
||||||
|
|||||||
@@ -344,7 +344,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
|||||||
|
|
||||||
def test_automodelforcausallm(self):
|
def test_automodelforcausallm(self):
|
||||||
"""
|
"""
|
||||||
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
||||||
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
|
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
|
||||||
"""
|
"""
|
||||||
config = self.model_tester.get_config()
|
config = self.model_tester.get_config()
|
||||||
@@ -352,7 +352,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
|||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
model.save_pretrained(tmp_dir)
|
model.save_pretrained(tmp_dir)
|
||||||
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||||
self.assertIsInstance(for_causal_lm, Gemma3ForCausalLM)
|
self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
Reference in New Issue
Block a user