From 950cfb0b4f6f79dbeb658100d25b5f3e3d8be04b Mon Sep 17 00:00:00 2001 From: Sambhav Dixit <94298612+sambhavnoobcoder@users.noreply.github.com> Date: Thu, 13 Feb 2025 14:41:44 +0530 Subject: [PATCH] Fix PaliGemma Pad Token Masking During Training #35855 (#35859) * change order of unmasking of tokens * library import * class setup * test function * refactor * add commit message * test modified * explict initiliasation of weights + made model smaller * removed sepete testing file * fixup * fixup core * test attention mask with token types * tests fixup * removed PaliGemmaAttentionMaskTest class --------- Co-authored-by: sambhavnoobcoder --- .../models/paligemma/modeling_paligemma.py | 14 ++++--- .../paligemma/test_modeling_paligemma.py | 41 +++++++++++++++++++ 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index b6dab1830c..9172b98c06 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -383,16 +383,20 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] + + # First unmask prefix tokens during training + if is_training: + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 + ) + + # Then apply padding mask (will mask pad tokens) padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - # we are training thus we need to create a full mask on the image + prefix but causal on suffix - if is_training: - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 - ) + return causal_mask def get_image_features(self, pixel_values: torch.FloatTensor): diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 7c72b03a5d..7d686576bd 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -351,6 +351,47 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes def test_generate_compile_model_forward(self): pass + def test_attention_mask_with_token_types(self): + """Test that attention masking works correctly both with and without token type IDs.""" + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + # Case 1: With token_type_ids + outputs_with_types = model( + **inputs_dict, + output_attentions=True, + ) + + # Case 2: Without token_type_ids + inputs_no_types = {k: v for k, v in inputs_dict.items() if k != "token_type_ids"} + outputs_no_types = model( + **inputs_no_types, + output_attentions=True, + ) + + attention_outputs_with_types = outputs_with_types.attentions + attention_outputs_no_types = outputs_no_types.attentions + + # Verify pad tokens remain masked in both cases + attention_mask = inputs_dict["attention_mask"] + pad_positions = attention_mask == 0 + + for layer_attentions in [attention_outputs_with_types, attention_outputs_no_types]: + for layer_attn in layer_attentions: + # Check if pad tokens are properly masked + for batch_idx in range(layer_attn.shape[0]): + for seq_idx in range(layer_attn.shape[-1]): + if pad_positions[batch_idx, seq_idx]: + # Verify attention weights for pad tokens are zero + self.assertTrue( + torch.all(layer_attn[batch_idx, :, :, seq_idx] == 0), + f"Found non-zero attention weights for padding token at batch {batch_idx}, sequence position {seq_idx}", + ) + @slow @require_torch