[gemma3] fix bidirectional attention mask (#38080)
* fix attn mask * attn viz doesn't show yello cubes between images * bucketize made it hard with different number of crops * fixup
This commit is contained in:
committed by
GitHub
parent
2edb0e4b4d
commit
f834d368f6
@@ -1062,10 +1062,21 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
||||
if token_type_ids is not None and sequence_length != 1:
|
||||
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
|
||||
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
|
||||
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
# Find where a new image block starts: 1 if image and previous not image
|
||||
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
||||
is_image = token_type_ids == 1
|
||||
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||
|
||||
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
|
||||
same_image_mask[image_group_ids == -1] = False # remove non-image
|
||||
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
causal_mask = causal_mask.clone()
|
||||
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
|
||||
token_type_mask, 0.0
|
||||
image_mask, 0.0
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
|
||||
@@ -781,10 +781,21 @@ class Gemma3Model(PaliGemmaModel):
|
||||
if token_type_ids is not None and sequence_length != 1:
|
||||
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
|
||||
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
|
||||
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
# Find where a new image block starts: 1 if image and previous not image
|
||||
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
||||
is_image = token_type_ids == 1
|
||||
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||
|
||||
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
|
||||
same_image_mask[image_group_ids == -1] = False # remove non-image
|
||||
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
causal_mask = causal_mask.clone()
|
||||
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
|
||||
token_type_mask, 0.0
|
||||
image_mask, 0.0
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
|
||||
@@ -36,7 +36,9 @@ BLACK_SQUARE = "■"
|
||||
WHITE_SQUARE = "⬚"
|
||||
|
||||
|
||||
def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_window=None, token_type_ids=None):
|
||||
def generate_attention_matrix_from_mask(
|
||||
words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
|
||||
):
|
||||
"""
|
||||
Generates an attention matrix from a given attention mask.
|
||||
|
||||
@@ -80,6 +82,14 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
|
||||
for j in range(n)
|
||||
)
|
||||
|
||||
if token_type_ids is not None:
|
||||
is_special = token_type_ids == 1
|
||||
token_type_buckets = torch.where(
|
||||
(token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
|
||||
)
|
||||
boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
|
||||
token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)
|
||||
|
||||
# Print headers
|
||||
legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
|
||||
output.append(" " + legend)
|
||||
@@ -103,7 +113,6 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
|
||||
if sliding_window is not None
|
||||
else ""
|
||||
)
|
||||
|
||||
for i, word in enumerate(words):
|
||||
word_repr = repr(word).ljust(max_word_length)
|
||||
colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
|
||||
@@ -121,7 +130,9 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
|
||||
if sliding_window is not None:
|
||||
sliding_window_row = " ".join(
|
||||
f"{YELLOW}{BLACK_SQUARE}{RESET}"
|
||||
if img_token in words[j] and img_token in words[i]
|
||||
if img_token in words[j]
|
||||
and img_token in words[i]
|
||||
and token_type_buckets[0, i] == token_type_buckets[0, j]
|
||||
else f"{GREEN}{BLACK_SQUARE}{RESET}"
|
||||
if i == j
|
||||
else BLACK_SQUARE
|
||||
@@ -170,7 +181,8 @@ class AttentionMaskVisualizer:
|
||||
if self.config.model_type in PROCESSOR_MAPPING_NAMES:
|
||||
img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
|
||||
img = Image.open(requests.get(img, stream=True).raw)
|
||||
processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=5)
|
||||
image_seq_length = 5
|
||||
processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
|
||||
if hasattr(processor, "image_token"):
|
||||
image_token = processor.image_token
|
||||
else:
|
||||
@@ -179,7 +191,7 @@ class AttentionMaskVisualizer:
|
||||
if image_token:
|
||||
input_sentence = input_sentence.replace("<img>", image_token)
|
||||
|
||||
inputs = processor(img, input_sentence, suffix=suffix, return_tensors="pt")
|
||||
inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")
|
||||
|
||||
self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
|
||||
|
||||
@@ -223,6 +235,7 @@ class AttentionMaskVisualizer:
|
||||
img_token=self.image_token,
|
||||
sliding_window=getattr(self.config, "sliding_window", None),
|
||||
token_type_ids=kwargs.get("token_type_ids", None),
|
||||
image_seq_length=image_seq_length,
|
||||
)
|
||||
print(f_string)
|
||||
print(f"{top_bottom_border}")
|
||||
|
||||
Reference in New Issue
Block a user