Add attention visualization tool (#36630)

* add utils  fiel

* style

* nits

* nits

* update

* updaets

* update

* fix init issues

* big updates

* nits

* nits?

* small updates

* nites

* there were still some models left

* style

* fixes

* updates

* nits _ fixes

* push changes

* update

* update

* update

* Apply suggestions from code review

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* style

* styling and return a string for testing

* small updates

* always biderectional for now

* update

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
Arthur
2025-03-19 13:58:46 +01:00
committed by GitHub
parent 0fe0bae0a8
commit fef8b7f8e9
55 changed files with 294 additions and 59 deletions

View File

@@ -1015,7 +1015,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():