From 3bc726b381592601cd9dd0fdcff5edcb02f3a85b Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 22 Jul 2025 10:04:56 +0200 Subject: [PATCH] [gemma3] fix bidirectional image mask (#39396) * fix gemma3 mask * make compile happy, and use only torch ops * no full attention between images * update tests * fix tests * add a fast test --- src/transformers/generation/utils.py | 4 +- .../models/gemma3/modeling_gemma3.py | 38 ++++++++++--- .../models/gemma3/modular_gemma3.py | 38 ++++++++++--- tests/models/gemma3/test_modeling_gemma3.py | 53 ++++++++++++++++--- 4 files changed, 112 insertions(+), 21 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 76b3d7bd8a..e360acdac3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -646,8 +646,8 @@ class GenerationMixin(ContinuousMixin): # If it's not defined, it means the model uses the new general mask API if causal_mask_creation_function is None: # can't be found - token_type_ids = getattr(model_input, "token_type_ids", None) - position_ids = getattr(model_input, position_ids_key, None) + token_type_ids = model_inputs.get("token_type_ids", None) + position_ids = model_inputs.get(position_ids_key, None) # Some models may overwrite the general one causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate) attention_mask = causal_mask_creation_function( diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index c02f2862c5..394e380021 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -737,7 +737,11 @@ class Gemma3MultiModalProjector(nn.Module): return projected_vision_outputs.type_as(vision_outputs) -def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]: +def token_type_ids_mask_function( + token_type_ids: Optional[torch.Tensor], + image_group_ids: Optional[torch.Tensor], + tokens_per_image: int, +) -> Optional[Callable]: """ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, not start and end indices. @@ -747,10 +751,18 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_ return None def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: - # If the difference is less than image size, both are part of the same image block - same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image # If it's 1 for both query and key/value, we are in an image block - is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1) + # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length + # Since vmap doesn't support `if statement` we workaround it with `torch.where` + safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0) + token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx] + token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0) + + image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx] + image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1) + + is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1) + same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx # This is bidirectional attention whenever we are dealing with image tokens return is_image_block & same_image_block @@ -915,8 +927,15 @@ class Gemma3Model(Gemma3PreTrainedModel): } if token_type_ids is not None and inputs_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + + # First find where a new image block starts: 1 if image and previous not image + # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally + is_image = (token_type_ids == 1).to(cache_position.device) + new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] + image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 + image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1)) mask_kwargs["or_mask_function"] = token_type_ids_mask_function( - token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image + token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image ) # Create the masks @@ -1181,8 +1200,15 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): # Add the token type ids mask for generate as well if token_type_ids is not None and input_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + + # First find where a new image block starts: 1 if image and previous not image + # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally + is_image = (token_type_ids == 1).to(cache_position.device) + new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] + image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 + image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1)) mask_kwargs["or_mask_function"] = token_type_ids_mask_function( - token_type_ids.to(cache_position.device), config.mm_tokens_per_image + token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image ) return create_masks_for_generate(**mask_kwargs) diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index c0a4f390c5..57ecedca91 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -716,7 +716,11 @@ class Gemma3MultiModalProjector(nn.Module): return projected_vision_outputs.type_as(vision_outputs) -def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]: +def token_type_ids_mask_function( + token_type_ids: Optional[torch.Tensor], + image_group_ids: Optional[torch.Tensor], + tokens_per_image: int, +) -> Optional[Callable]: """ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, not start and end indices. @@ -726,10 +730,18 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_ return None def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: - # If the difference is less than image size, both are part of the same image block - same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image # If it's 1 for both query and key/value, we are in an image block - is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1) + # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length + # Since vmap doesn't support `if statement` we workaround it with `torch.where` + safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0) + token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx] + token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0) + + image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx] + image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1) + + is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1) + same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx # This is bidirectional attention whenever we are dealing with image tokens return is_image_block & same_image_block @@ -840,8 +852,15 @@ class Gemma3Model(PaliGemmaModel): } if token_type_ids is not None and inputs_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + + # First find where a new image block starts: 1 if image and previous not image + # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally + is_image = (token_type_ids == 1).to(cache_position.device) + new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] + image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 + image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1)) mask_kwargs["or_mask_function"] = token_type_ids_mask_function( - token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image + token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image ) # Create the masks @@ -1062,8 +1081,15 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): # Add the token type ids mask for generate as well if token_type_ids is not None and input_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` + + # First find where a new image block starts: 1 if image and previous not image + # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally + is_image = (token_type_ids == 1).to(cache_position.device) + new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1] + image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1 + image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1)) mask_kwargs["or_mask_function"] = token_type_ids_mask_function( - token_type_ids.to(cache_position.device), config.mm_tokens_per_image + token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image ) return create_masks_for_generate(**mask_kwargs) diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 2a1314c85a..3817acfd50 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -270,6 +270,45 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte self.model_tester = Gemma3Vision2TextModelTester(self) self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37) + def test_bidirectional_image_attention(self): + """ + Tests that each image can attend to itself bidirectionally. However an image + cannot attend to future images, even within the same batch. + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "eager" + model = Gemma3Model(config).to(torch_device) + + # First let's pass inputs without change which is one image per text and manipulate + # `token_type_ids` to make sure bidirectional mask is applied where it has to be + inputs_dict["token_type_ids"] = torch.zeros_like(inputs_dict["token_type_ids"]) + inputs_dict["token_type_ids"][:, :4] = 1 # unmask first 4 tokens + with torch.no_grad(): + out = model(**inputs_dict, output_attentions=True) + # We expect a non-causal mask on first 4 tokens, thus no zeros + for attention in out.attentions: + self.assertTrue((attention[..., :4, :4] != 0).all().item()) + + # Now when removing `token_type_ids`, we will get simple causal mask + inputs_dict["token_type_ids"][:, :4] = 0 # mask back first 4 tokens + with torch.no_grad(): + out = model(**inputs_dict, output_attentions=True) + # We expect a causal mask on first 4 tokens, thus no zeros + for attention in out.attentions: + self.assertFalse((attention[..., :4, :4] != 0).all().item()) + + # Let's add two "images" per text, first one spanning 4 tokens and last one 3 tokens + inputs_dict["token_type_ids"][:, :4] = 1 + inputs_dict["token_type_ids"][:, 7:10] = 1 + with torch.no_grad(): + out = model(**inputs_dict, output_attentions=True) + for attention in out.attentions: + self.assertTrue((attention[..., :4, :4] != 0).all().item()) + self.assertTrue((attention[..., 7:10, 7:10] != 0).all().item()) + + # We expect a non-causal mask only within same image and no looking ahead to the future + self.assertTrue((attention[..., :4, 7:10] == 0).all().item()) + @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training") def test_training_gradient_checkpointing(self): pass @@ -413,7 +452,7 @@ class Gemma3IntegrationTest(unittest.TestCase): EXPECTED_TEXTS = Expectations( { ("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water in the background. It looks like a lovely,'], - ("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'], + ("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like'], ("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'], } ) # fmt: skip @@ -463,8 +502,8 @@ class Gemma3IntegrationTest(unittest.TestCase): ], ("cuda", 8): [ - 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks', - 'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes. \n\n* **Image 1** shows a cow standing on a beach' + 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like', + "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a brown" ], ("rocm", (9, 5)): [ @@ -508,7 +547,7 @@ class Gemma3IntegrationTest(unittest.TestCase): { ("xpu", 3): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'], ("cuda", 7): [], - ("cuda", 8): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'], + ("cuda", 8): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the"], } ) # fmt: skip EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() @@ -565,8 +604,8 @@ class Gemma3IntegrationTest(unittest.TestCase): ], ("cuda", 7): [], ("cuda", 8): [ - 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.', - 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a', + "user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the", + 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a' ], ("rocm", (9, 5)) : [ 'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.', @@ -610,7 +649,7 @@ class Gemma3IntegrationTest(unittest.TestCase): { ("xpu", 3): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image!\n\nHere's a description of the scene:\n\n* **Chinese Arch"], ("cuda", 7): [], - ("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"], + ("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"], } ) # fmt: skip EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()