[gemma3] fix bidirectional image mask (#39396)

* fix gemma3 mask

* make compile happy, and use only torch ops

* no full attention between images

* update tests

* fix tests

* add a fast test
This commit is contained in:
Raushan Turganbay
2025-07-22 10:04:56 +02:00
committed by GitHub
parent fbeaf96f9e
commit 3bc726b381
4 changed files with 112 additions and 21 deletions

View File

@@ -270,6 +270,45 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
self.model_tester = Gemma3Vision2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
def test_bidirectional_image_attention(self):
"""
Tests that each image can attend to itself bidirectionally. However an image
cannot attend to future images, even within the same batch.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "eager"
model = Gemma3Model(config).to(torch_device)
# First let's pass inputs without change which is one image per text and manipulate
# `token_type_ids` to make sure bidirectional mask is applied where it has to be
inputs_dict["token_type_ids"] = torch.zeros_like(inputs_dict["token_type_ids"])
inputs_dict["token_type_ids"][:, :4] = 1 # unmask first 4 tokens
with torch.no_grad():
out = model(**inputs_dict, output_attentions=True)
# We expect a non-causal mask on first 4 tokens, thus no zeros
for attention in out.attentions:
self.assertTrue((attention[..., :4, :4] != 0).all().item())
# Now when removing `token_type_ids`, we will get simple causal mask
inputs_dict["token_type_ids"][:, :4] = 0 # mask back first 4 tokens
with torch.no_grad():
out = model(**inputs_dict, output_attentions=True)
# We expect a causal mask on first 4 tokens, thus no zeros
for attention in out.attentions:
self.assertFalse((attention[..., :4, :4] != 0).all().item())
# Let's add two "images" per text, first one spanning 4 tokens and last one 3 tokens
inputs_dict["token_type_ids"][:, :4] = 1
inputs_dict["token_type_ids"][:, 7:10] = 1
with torch.no_grad():
out = model(**inputs_dict, output_attentions=True)
for attention in out.attentions:
self.assertTrue((attention[..., :4, :4] != 0).all().item())
self.assertTrue((attention[..., 7:10, 7:10] != 0).all().item())
# We expect a non-causal mask only within same image and no looking ahead to the future
self.assertTrue((attention[..., :4, 7:10] == 0).all().item())
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing(self):
pass
@@ -413,7 +452,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water in the background. It looks like a lovely,'],
("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'],
("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like'],
("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'],
}
) # fmt: skip
@@ -463,8 +502,8 @@ class Gemma3IntegrationTest(unittest.TestCase):
],
("cuda", 8):
[
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks',
'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes. \n\n* **Image 1** shows a cow standing on a beach'
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like',
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a brown"
],
("rocm", (9, 5)):
[
@@ -508,7 +547,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
{
("xpu", 3): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
("cuda", 7): [],
("cuda", 8): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
("cuda", 8): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the"],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
@@ -565,8 +604,8 @@ class Gemma3IntegrationTest(unittest.TestCase):
],
("cuda", 7): [],
("cuda", 8): [
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a',
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the",
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a'
],
("rocm", (9, 5)) : [
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
@@ -610,7 +649,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
{
("xpu", 3): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image!\n\nHere's a description of the scene:\n\n* **Chinese Arch"],
("cuda", 7): [],
("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"],
("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()