make LlamaModel._update_causal_mask torch compilable (#35187)
* make LlamaModel._update_causal_mask torch compilable
* chore: lint (make fix-copies)
* fix-copies

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
@@ -1012,7 +1012,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
output_attentions: bool,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
- if attention_mask is not None and 0.0 in attention_mask:
|
||||
+ if attention_mask is not None and (attention_mask == 0.0).any():
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user