From fef8b7f8e9fd2800f5b4cb7483f9889bd437257e Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 19 Mar 2025 13:58:46 +0100 Subject: [PATCH] Add attention visualization tool (#36630) * add utils fiel * style * nits * nits * update * updaets * update * fix init issues * big updates * nits * nits? * small updates * nites * there were still some models left * style * fixes * updates * nits _ fixes * push changes * update * update * update * Apply suggestions from code review Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * style * styling and return a string for testing * small updates * always biderectional for now * update --------- Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> --- src/transformers/models/aria/modeling_aria.py | 2 +- .../models/bloom/modeling_bloom.py | 2 +- .../models/chameleon/modeling_chameleon.py | 2 +- .../models/codegen/modeling_codegen.py | 2 +- .../models/cohere/modeling_cohere.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 2 +- .../models/diffllama/modeling_diffllama.py | 2 +- src/transformers/models/emu3/modeling_emu3.py | 2 +- .../models/gemma/modeling_gemma.py | 2 +- .../models/gemma3/processing_gemma3.py | 1 + src/transformers/models/glm/modeling_glm.py | 2 +- .../models/gpt_neo/modeling_gpt_neo.py | 2 +- .../models/gpt_neox/modeling_gpt_neox.py | 2 +- .../modeling_gpt_neox_japanese.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 2 +- .../models/granite/modeling_granite.py | 2 +- .../models/granitemoe/modeling_granitemoe.py | 2 +- .../modeling_granitemoeshared.py | 2 +- .../models/helium/modeling_helium.py | 2 +- .../models/idefics/modeling_idefics.py | 2 +- .../models/jetmoe/modeling_jetmoe.py | 2 +- .../models/llama/modeling_llama.py | 2 +- .../models/longt5/modeling_longt5.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 2 +- .../models/mistral/modeling_mistral.py | 2 +- .../models/mistral/modular_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 2 +- .../models/mllama/modeling_mllama.py | 2 +- .../models/moonshine/modeling_moonshine.py | 2 +- .../models/moshi/modeling_moshi.py | 4 +- src/transformers/models/mt5/modeling_mt5.py | 2 +- .../models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- .../models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/opt/modeling_opt.py | 2 +- .../models/paligemma/modeling_paligemma.py | 17 +- .../models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- .../models/phimoe/modeling_phimoe.py | 2 +- .../models/pix2struct/modeling_pix2struct.py | 2 +- .../models/pop2piano/modeling_pop2piano.py | 2 +- .../models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 2 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 2 +- .../models/stablelm/modeling_stablelm.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- .../modeling_switch_transformers.py | 2 +- src/transformers/models/t5/modeling_t5.py | 2 +- src/transformers/models/udop/modeling_udop.py | 2 +- src/transformers/models/umt5/modeling_umt5.py | 2 +- .../models/whisper/modeling_whisper.py | 2 +- src/transformers/processing_utils.py | 2 +- .../utils/attention_visualizer.py | 229 ++++++++++++++++++ 55 files changed, 294 insertions(+), 59 deletions(-) create mode 100644 src/transformers/utils/attention_visualizer.py diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 252ddb694f..e44cd709df 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1015,7 +1015,7 @@ class AriaTextModel(AriaTextPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index da9528cbc4..c5783a3b13 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -746,7 +746,7 @@ class BloomModel(BloomPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index bb09eb26cc..30a8ab60f2 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1390,7 +1390,7 @@ class ChameleonModel(ChameleonPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 89e46a5523..1d7d3b6d27 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -592,7 +592,7 @@ class CodeGenModel(CodeGenPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 018fc3b862..01b5ee3b9e 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -665,7 +665,7 @@ class CohereModel(CoherePreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 4466528948..08f11ed6ff 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1119,7 +1119,7 @@ class DbrxModel(DbrxPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index f490525d99..d69dcf3b59 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -904,7 +904,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 8ad29c02ad..47360222fc 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1483,7 +1483,7 @@ class Emu3TextModel(Emu3PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 0830527472..807abe2b9f 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -637,7 +637,7 @@ class GemmaModel(GemmaPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/gemma3/processing_gemma3.py b/src/transformers/models/gemma3/processing_gemma3.py index 00cece56a2..c5492f3b76 100644 --- a/src/transformers/models/gemma3/processing_gemma3.py +++ b/src/transformers/models/gemma3/processing_gemma3.py @@ -65,6 +65,7 @@ class Gemma3Processor(ProcessorMixin): self.image_seq_length = image_seq_length self.image_token_id = tokenizer.image_token_id self.boi_token = tokenizer.boi_token + self.image_token = tokenizer.boi_token image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length) self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n" diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 53f488116d..d155beff77 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -646,7 +646,7 @@ class GlmModel(GlmPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index e72fff18e0..13ff3cd740 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -796,7 +796,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 9bbd94d798..97590faf70 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -640,7 +640,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 3e13ce09c5..aab7183ccf 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -668,7 +668,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 24f224ad1b..effaa3b725 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -895,7 +895,7 @@ class GPTJModel(GPTJPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 1822bd627d..e7a95c719f 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -649,7 +649,7 @@ class GraniteModel(GranitePreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 73086f958c..192c99df7f 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1122,7 +1122,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 3dc86991c7..d97620e62b 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -1067,7 +1067,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index c3f57149d8..40aac3b933 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -633,7 +633,7 @@ class HeliumModel(HeliumPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 3ca196936c..da18b35d48 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1367,7 +1367,7 @@ class IdeficsModel(IdeficsPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 814c62a5fd..9518ef0be5 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1128,7 +1128,7 @@ class JetMoeModel(JetMoePreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 159f41b3ce..e52d4087be 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -635,7 +635,7 @@ class LlamaModel(LlamaPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 9dce316236..70ec2db49f 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1604,7 +1604,7 @@ class LongT5Stack(LongT5PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 8eece150cf..4539cfd0bd 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1066,7 +1066,7 @@ class MimiTransformerModel(nn.Module): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 4c0f6b8b68..07a3e7ed41 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -600,7 +600,7 @@ class MistralModel(MistralPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 10337f4eef..20f0528627 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -118,7 +118,7 @@ class MistralModel(LlamaModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 4112bed373..8fd3f7540e 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -734,7 +734,7 @@ class MixtralModel(MixtralPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index fb0ff7d045..1981f4287b 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1081,7 +1081,7 @@ class MllamaPreTrainedModel(PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 963b7e7aa2..fc4770fd8f 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -999,7 +999,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 6a6151b678..3d686ba34d 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1296,7 +1296,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: @@ -1610,7 +1610,7 @@ class MoshiModel(MoshiPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 9c2d23b7af..558b671234 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1195,7 +1195,7 @@ class MT5Stack(MT5PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index ef73c55e0f..d6d77887a2 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -883,7 +883,7 @@ class NemotronModel(NemotronPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 2f6dc1fa9c..626b248e2c 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -611,7 +611,7 @@ class OlmoModel(OlmoPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 05b8d17223..101f79750e 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -612,7 +612,7 @@ class Olmo2Model(Olmo2PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index a306e84b23..28134e141c 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -643,7 +643,7 @@ class OPTDecoder(OPTPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index c49e91b282..5d8542c1d3 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -340,19 +340,22 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi def _update_causal_mask( self, attention_mask, - token_type_ids, - past_key_values, - cache_position, - input_tensor, - is_training: bool = False, + token_type_ids=None, + past_key_values=None, + cache_position=None, + input_tensor=None, + is_training: bool = None, ): if self.config.text_config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None - + is_training = is_training if is_training is not None else self.training using_static_cache = isinstance(past_key_values, StaticCache) min_dtype = torch.finfo(self.dtype).min + if input_tensor is None: + input_tensor = attention_mask + inputs_lead_dim, sequence_length = input_tensor.shape[:2] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -387,6 +390,8 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi # First unmask prefix tokens during training if is_training: + if token_type_ids is None: + raise ValueError("Token type ids must be provided during training") causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index df6153c811..a71169dbdd 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -683,7 +683,7 @@ class PersimmonModel(PersimmonPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index cac12d59b0..c273f90628 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -609,7 +609,7 @@ class PhiModel(PhiPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 6712f24f41..8b49ad0876 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -675,7 +675,7 @@ class Phi3Model(Phi3PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index f06255030e..66452fd943 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -1180,7 +1180,7 @@ class PhimoeModel(PhimoePreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 42e190d2e4..63c392db12 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1591,7 +1591,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 0c4a2eda97..cadf71871e 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -1004,7 +1004,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 53f675192b..d5d922e208 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -613,7 +613,7 @@ class Qwen2Model(Qwen2PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index a90baf246e..fb0c8f5dbe 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1246,7 +1246,7 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 036bdb26ca..80e215854b 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1068,7 +1068,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index dfea703003..66b780f312 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1192,7 +1192,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 4621851fea..2f59cda241 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -938,7 +938,7 @@ class StableLmModel(StableLmPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 0187f733ab..a23dfa9d46 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -596,7 +596,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index a347654a49..73282a1509 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1140,7 +1140,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index ba96c10ed0..0851c0eac9 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1209,7 +1209,7 @@ class T5Stack(T5PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 2b0e27f45e..089434ca82 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1542,7 +1542,7 @@ class UdopStack(UdopPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 7b868696f6..07c44bef7b 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -852,7 +852,7 @@ class UMT5Stack(UMT5PreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 2bcf4026a3..dacd147c15 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1378,7 +1378,7 @@ class WhisperDecoder(WhisperPreTrainedModel): input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index c85abdc739..4b0c2ebdd4 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1069,7 +1069,7 @@ class ProcessorMixin(PushToHubMixin): args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs) - + processor_dict.update({k: v for k, v in kwargs.items() if k in processor_dict.keys()}) return cls.from_args_and_dict(args, processor_dict, **kwargs) @classmethod diff --git a/src/transformers/utils/attention_visualizer.py b/src/transformers/utils/attention_visualizer.py new file mode 100644 index 0000000000..54efdbfcda --- /dev/null +++ b/src/transformers/utils/attention_visualizer.py @@ -0,0 +1,229 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import requests +from PIL import Image + +from ..models.auto.auto_factory import _get_model_class +from ..models.auto.configuration_auto import AutoConfig +from ..models.auto.modeling_auto import MODEL_FOR_PRETRAINING_MAPPING, MODEL_MAPPING +from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES, AutoProcessor +from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES, AutoTokenizer +from .import_utils import is_torch_available + + +if is_torch_available(): + import torch + import torch.nn as nn + +# Print the matrix with words as row labels +GREEN = "\033[92m" +YELLOW = "\033[93m" +RESET = "\033[0m" +BLACK_SQUARE = "■" +WHITE_SQUARE = "⬚" + + +def generate_attention_matrix_from_mask(words, mask, img_token="", sliding_window=None, token_type_ids=None): + """ + Generates an attention matrix from a given attention mask. + + Optionally applies a sliding window mask (e.g., for Gemma2/3) and + marks regions where image tokens occur based on the specified `img_token`. + """ + mask = mask.int() + if mask.ndim == 3: + mask = mask[0, :, :] + if mask.ndim == 4: + mask = mask[0, 0, :, :] + + n = len(words) + max_word_length = max(len(repr(word)) for word in words) + first_img_idx = 0 + output = [] + + for i, k in enumerate(words): + if k == img_token and not first_img_idx: + first_img_idx = i + mask[i, i] = 2 # Mark yellow regions + if first_img_idx > 0 and (k != img_token or i == n - 1): + if i == n - 1: + i += 1 + mask[first_img_idx:i, first_img_idx:i] = 2 # Mark yellow regions + first_img_idx = 0 + + # Generate sliding window mask (size = 4), excluding img_token + sliding_window_mask = None + if sliding_window is not None: + sliding_window_mask = [[1 if (0 <= i - j < sliding_window) else 0 for j in range(n)] for i in range(n)] + + row_dummy = " ".join( + f"{YELLOW}{BLACK_SQUARE}{RESET}" + if mask[0, j] + else f"{GREEN}{BLACK_SQUARE}{RESET}" + if 0 == j + else BLACK_SQUARE + if mask[0, j] + else WHITE_SQUARE + for j in range(n) + ) + + # Print headers + legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids" + output.append(" " + legend) + f_string = " " * (max_word_length + 5) + "Attention Matrix".ljust(len(row_dummy) // 2) + if sliding_window is not None: + f_string += "Sliding Window Mask" + output.append(f_string) + + vertical_header = [] + for idx, word in enumerate(words): + if mask[idx, idx] == 2: + vertical_header.append([f"{YELLOW}{k}{RESET}" for k in list(str(idx).rjust(len(str(n))))]) + else: + vertical_header.append(list(str(idx).rjust(len(str(n))))) + + vertical_header = list(map(list, zip(*vertical_header))) # Transpose + + for row in vertical_header: + output.append( + (max_word_length + 5) * " " + " ".join(row) + " | " + " ".join(row) + if sliding_window is not None + else "" + ) + + for i, word in enumerate(words): + word_repr = repr(word).ljust(max_word_length) + colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr + row_display = " ".join( + f"{YELLOW}{BLACK_SQUARE}{RESET}" + if img_token in words[j] and mask[i, j] and img_token in words[i] + else f"{GREEN}{BLACK_SQUARE}{RESET}" + if i == j + else BLACK_SQUARE + if mask[i, j] + else WHITE_SQUARE + for j in range(n) + ) + sliding_window_row = "" + if sliding_window is not None: + sliding_window_row = " ".join( + f"{YELLOW}{BLACK_SQUARE}{RESET}" + if img_token in words[j] and img_token in words[i] + else f"{GREEN}{BLACK_SQUARE}{RESET}" + if i == j + else BLACK_SQUARE + if sliding_window_mask[i][j] + else WHITE_SQUARE + for j in range(n) + ) + + output.append(f"{colored_word}: {str(i).rjust(2)} {row_display} | {sliding_window_row}") + + return "\n".join(output) + + +class AttentionMaskVisualizer: + def __init__(self, model_name: str): + config = AutoConfig.from_pretrained(model_name) + self.image_token = "" + if hasattr(config.get_text_config(), "sliding_window"): + config.sliding_window = 5 + try: + mapped_cls = _get_model_class(config, MODEL_MAPPING) + except Exception: + mapped_cls = _get_model_class(config, MODEL_FOR_PRETRAINING_MAPPING) + + if mapped_cls is None: + raise ValueError(f"Model name {model_name} is not supported for attention visualization") + self.mapped_cls = mapped_cls + + class _ModelWrapper(mapped_cls, nn.Module): + def __init__(self, config, model_name): + nn.Module.__init__(self) + self.dummy_module = nn.Linear(1, 1) + self.config = config + + self.model = _ModelWrapper(config, model_name) + self.model.to(config.torch_dtype) + self.repo_id = model_name + self.config = config + + def __call__(self, input_sentence: str, suffix=""): + self.visualize_attention_mask(input_sentence, suffix=suffix) + + def visualize_attention_mask(self, input_sentence: str, suffix=""): + model = self.model + kwargs = {} + if self.config.model_type in PROCESSOR_MAPPING_NAMES: + img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true" + img = Image.open(requests.get(img, stream=True).raw) + processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=5) + if hasattr(processor, "image_token"): + image_token = processor.image_token + else: + image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0] + + if image_token: + input_sentence = input_sentence.replace("", image_token) + + inputs = processor(img, input_sentence, suffix=suffix, return_tensors="pt") + + self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0] + + attention_mask = inputs["attention_mask"] + if "token_type_ids" in inputs: # TODO inspect signature of update causal mask + kwargs["token_type_ids"] = inputs["token_type_ids"] + tokens = processor.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) + elif self.config.model_type in TOKENIZER_MAPPING_NAMES: + tokenizer = AutoTokenizer.from_pretrained(self.repo_id) + tokens = tokenizer.tokenize(input_sentence) + attention_mask = tokenizer(input_sentence, return_tensors="pt")["attention_mask"] + else: + raise ValueError(f"Model type {model.config.model_type} does not support attention visualization") + + model.config._attn_implementation = "eager" + model.train() + attention_mask = ~model._update_causal_mask( + attention_mask=attention_mask, + input_tensor=attention_mask.to(self.model.dtype), + cache_position=torch.arange(attention_mask.shape[1]), + past_key_values=None, + **kwargs, + ).bool() + top_bottom_border = "##" * ( + len(f"Attention visualization for {self.config.model_type} | {self.mapped_cls}") + 4 + ) # Box width adjusted to text length + side_border = "##" + print(f"\n{top_bottom_border}") + print( + "##" + + f" Attention visualization for \033[1m{self.config.model_type}:{self.repo_id}\033[0m {self.mapped_cls.__name__}".center( + len(top_bottom_border) + ) + + " " + + side_border + ) + print(f"{top_bottom_border}") + f_string = generate_attention_matrix_from_mask( + tokens, + attention_mask, + img_token=self.image_token, + sliding_window=getattr(self.config, "sliding_window", None), + token_type_ids=kwargs.get("token_type_ids", None), + ) + print(f_string) + print(f"{top_bottom_border}")