[Feat] Support npu in modeling models (#37369)

This commit is contained in:
duanjunwen
2025-04-11 01:00:58 +08:00
committed by GitHub
parent 10907e2846
commit 7ff896c0f2
65 changed files with 70 additions and 66 deletions

View File

@@ -1040,7 +1040,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
-and attention_mask.device.type in ["cuda", "xpu"]
+and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when