make LlamaModel._update_causal_mask torch compilable (#35187)

* make LlamaModel._update_causal_mask torch compilable

* chore: lint (make fix-copies)

* fix-copies

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Wing Lian
2024-12-23 07:10:00 -05:00
committed by GitHub
parent 401aa39d7b
commit 5e7aedebeb
33 changed files with 33 additions and 33 deletions

View File

@@ -1012,7 +1012,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and 0.0 in attention_mask:
+ if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None