[gemma3] fix bidirectional attention mask (#38080)

* fix attn mask

* attn viz doesn't show yello cubes between images

* bucketize made it hard with different number of crops

* fixup
This commit is contained in:
Raushan Turganbay
2025-05-20 17:35:04 +02:00
committed by GitHub
parent 2edb0e4b4d
commit f834d368f6
3 changed files with 44 additions and 9 deletions

View File

@@ -1062,10 +1062,21 @@ class Gemma3Model(Gemma3PreTrainedModel):
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
# Find where a new image block starts: 1 if image and previous not image
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
is_image = token_type_ids == 1
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
same_image_mask[image_group_ids == -1] = False # remove non-image
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
token_type_mask, 0.0
image_mask, 0.0
)
if attention_mask is not None:

View File

@@ -781,10 +781,21 @@ class Gemma3Model(PaliGemmaModel):
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
# Find where a new image block starts: 1 if image and previous not image
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
is_image = token_type_ids == 1
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
same_image_mask[image_group_ids == -1] = False # remove non-image
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
token_type_mask, 0.0
image_mask, 0.0
)
if attention_mask is not None:

View File

@@ -36,7 +36,9 @@ BLACK_SQUARE = "■"
WHITE_SQUARE = ""
def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_window=None, token_type_ids=None):
def generate_attention_matrix_from_mask(
words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
):
"""
Generates an attention matrix from a given attention mask.
@@ -80,6 +82,14 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
for j in range(n)
)
if token_type_ids is not None:
is_special = token_type_ids == 1
token_type_buckets = torch.where(
(token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
)
boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)
# Print headers
legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
output.append(" " + legend)
@@ -103,7 +113,6 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
if sliding_window is not None
else ""
)
for i, word in enumerate(words):
word_repr = repr(word).ljust(max_word_length)
colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
@@ -121,7 +130,9 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
if sliding_window is not None:
sliding_window_row = " ".join(
f"{YELLOW}{BLACK_SQUARE}{RESET}"
if img_token in words[j] and img_token in words[i]
if img_token in words[j]
and img_token in words[i]
and token_type_buckets[0, i] == token_type_buckets[0, j]
else f"{GREEN}{BLACK_SQUARE}{RESET}"
if i == j
else BLACK_SQUARE
@@ -170,7 +181,8 @@ class AttentionMaskVisualizer:
if self.config.model_type in PROCESSOR_MAPPING_NAMES:
img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
img = Image.open(requests.get(img, stream=True).raw)
processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=5)
image_seq_length = 5
processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
if hasattr(processor, "image_token"):
image_token = processor.image_token
else:
@@ -179,7 +191,7 @@ class AttentionMaskVisualizer:
if image_token:
input_sentence = input_sentence.replace("<img>", image_token)
inputs = processor(img, input_sentence, suffix=suffix, return_tensors="pt")
inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")
self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
@@ -223,6 +235,7 @@ class AttentionMaskVisualizer:
img_token=self.image_token,
sliding_window=getattr(self.config, "sliding_window", None),
token_type_ids=kwargs.get("token_type_ids", None),
image_seq_length=image_seq_length,
)
print(f_string)
print(f"{top_bottom_border}")