Add attention visualization tool (#36630)
* add utils fiel
* style
* nits
* nits
* update
* updaets
* update
* fix init issues
* big updates
* nits
* nits?
* small updates
* nites
* there were still some models left
* style
* fixes
* updates
* nits _ fixes
* push changes
* update
* update
* update
* Apply suggestions from code review

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* style
* styling and return a string for testing
* small updates
* always biderectional for now
* update

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
@@ -1015,7 +1015,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
         past_key_values: Cache,
-        output_attentions: bool,
+        output_attentions: bool = False,
     ):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and (attention_mask == 0.0).any():
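
A minimal sketch of what this signature change permits, using a hypothetical stand-in function rather than the real `AriaTextModel` method: once `output_attentions` defaults to `False`, callers that do not need attention weights can omit the argument, and callers that already pass it are unaffected. The name `update_causal_mask_stub`, its parameter order, and its return value are assumptions made for illustration only.

```python
import inspect


def update_causal_mask_stub(
    attention_mask,
    input_tensor,
    cache_position,
    past_key_values,
    output_attentions: bool = False,  # default mirrors the change in the hunk
):
    """Hypothetical stand-in: it only echoes the flag it received."""
    return output_attentions


# Call that omits the flag: it falls back to the default, False.
omitted = update_causal_mask_stub(None, None, None, None)

# Call that passes the flag explicitly: behaves as it did before the change.
explicit = update_causal_mask_stub(None, None, None, None, output_attentions=True)

# The default is also visible through introspection.
default = (
    inspect.signature(update_causal_mask_stub)
    .parameters["output_attentions"]
    .default
)
```

Under the old signature, where the parameter had no default, the first call would raise a `TypeError` for the missing positional argument.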