Fix PaliGemma Pad Token Masking During Training #35855 (#35859)

* change order of unmasking of tokens

* library import

* class setup

* test function

* refactor

* add commit message

* test modified

* explict initiliasation of weights + made model smaller

* removed sepete testing file

* fixup

* fixup core

* test attention mask with token types

* tests fixup

* removed PaliGemmaAttentionMaskTest class

---------

Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
This commit is contained in:
Sambhav Dixit
2025-02-13 14:41:44 +05:30
committed by GitHub
parent 1614d196e8
commit 950cfb0b4f
2 changed files with 50 additions and 5 deletions

View File

@@ -351,6 +351,47 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes
def test_generate_compile_model_forward(self):
pass
def test_attention_mask_with_token_types(self):
"""Test that attention masking works correctly both with and without token type IDs."""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
# Case 1: With token_type_ids
outputs_with_types = model(
**inputs_dict,
output_attentions=True,
)
# Case 2: Without token_type_ids
inputs_no_types = {k: v for k, v in inputs_dict.items() if k != "token_type_ids"}
outputs_no_types = model(
**inputs_no_types,
output_attentions=True,
)
attention_outputs_with_types = outputs_with_types.attentions
attention_outputs_no_types = outputs_no_types.attentions
# Verify pad tokens remain masked in both cases
attention_mask = inputs_dict["attention_mask"]
pad_positions = attention_mask == 0
for layer_attentions in [attention_outputs_with_types, attention_outputs_no_types]:
for layer_attn in layer_attentions:
# Check if pad tokens are properly masked
for batch_idx in range(layer_attn.shape[0]):
for seq_idx in range(layer_attn.shape[-1]):
if pad_positions[batch_idx, seq_idx]:
# Verify attention weights for pad tokens are zero
self.assertTrue(
torch.all(layer_attn[batch_idx, :, :, seq_idx] == 0),
f"Found non-zero attention weights for padding token at batch {batch_idx}, sequence position {seq_idx}",
)
@slow
@require_torch