* change order of unmasking of tokens * library import * class setup * test function * refactor * add commit message * test modified * explicit initialisation of weights + made model smaller * removed separate testing file * fixup * fixup core * test attention mask with token types * tests fixup * removed PaliGemmaAttentionMaskTest class --------- Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
This commit is contained in:
@@ -383,16 +383,20 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
|||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
|
|
||||||
|
# First unmask prefix tokens during training
|
||||||
|
if is_training:
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then apply padding mask (will mask pad tokens)
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
||||||
padding_mask = padding_mask == 0
|
padding_mask = padding_mask == 0
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
padding_mask, min_dtype
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
|
||||||
if is_training:
|
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
|
||||||
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
|
||||||
)
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||||
|
|||||||
@@ -351,6 +351,47 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
|
|||||||
def test_generate_compile_model_forward(self):
|
def test_generate_compile_model_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_attention_mask_with_token_types(self):
|
||||||
|
"""Test that attention masking works correctly both with and without token type IDs."""
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Case 1: With token_type_ids
|
||||||
|
outputs_with_types = model(
|
||||||
|
**inputs_dict,
|
||||||
|
output_attentions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Case 2: Without token_type_ids
|
||||||
|
inputs_no_types = {k: v for k, v in inputs_dict.items() if k != "token_type_ids"}
|
||||||
|
outputs_no_types = model(
|
||||||
|
**inputs_no_types,
|
||||||
|
output_attentions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
attention_outputs_with_types = outputs_with_types.attentions
|
||||||
|
attention_outputs_no_types = outputs_no_types.attentions
|
||||||
|
|
||||||
|
# Verify pad tokens remain masked in both cases
|
||||||
|
attention_mask = inputs_dict["attention_mask"]
|
||||||
|
pad_positions = attention_mask == 0
|
||||||
|
|
||||||
|
for layer_attentions in [attention_outputs_with_types, attention_outputs_no_types]:
|
||||||
|
for layer_attn in layer_attentions:
|
||||||
|
# Check if pad tokens are properly masked
|
||||||
|
for batch_idx in range(layer_attn.shape[0]):
|
||||||
|
for seq_idx in range(layer_attn.shape[-1]):
|
||||||
|
if pad_positions[batch_idx, seq_idx]:
|
||||||
|
# Verify attention weights for pad tokens are zero
|
||||||
|
self.assertTrue(
|
||||||
|
torch.all(layer_attn[batch_idx, :, :, seq_idx] == 0),
|
||||||
|
f"Found non-zero attention weights for padding token at batch {batch_idx}, sequence position {seq_idx}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user