From 5e7aedebebbdee0d7eb0b8b2d771e45783dbf8c7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 23 Dec 2024 07:10:00 -0500 Subject: [PATCH] make LlamaModel._update_causal_mask torch compilable (#35187) * make LlamaModel._update_causal_mask torch compilable * chore: lint (make fix-copies) * fix-copies --------- Co-authored-by: Arthur Zucker --- src/transformers/models/aria/modeling_aria.py | 2 +- src/transformers/models/bloom/modeling_bloom.py | 2 +- src/transformers/models/chameleon/modeling_chameleon.py | 2 +- src/transformers/models/codegen/modeling_codegen.py | 2 +- src/transformers/models/cohere/modeling_cohere.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 2 +- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 2 +- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 2 +- src/transformers/models/granite/modeling_granite.py | 2 +- src/transformers/models/idefics/modeling_idefics.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/longt5/modeling_longt5.py | 2 +- src/transformers/models/mllama/modeling_mllama.py | 2 +- src/transformers/models/mt5/modeling_mt5.py | 2 +- src/transformers/models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/pix2struct/modeling_pix2struct.py | 2 +- src/transformers/models/pop2piano/modeling_pop2piano.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- src/transformers/models/stablelm/modeling_stablelm.py | 2 +- .../models/switch_transformers/modeling_switch_transformers.py | 2 +- src/transformers/models/t5/modeling_t5.py | 2 +- src/transformers/models/udop/modeling_udop.py | 2 +- src/transformers/models/umt5/modeling_umt5.py | 2 +- src/transformers/models/whisper/modeling_whisper.py | 2 +- 33 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 6481d6f3c4..b96697bc07 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1012,7 +1012,7 @@ class AriaTextModel(AriaTextPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 086f8ce03c..9d7325c502 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -740,7 +740,7 @@ class BloomModel(BloomPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 11bc411a00..90a02dd5bb 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1385,7 +1385,7 @@ class ChameleonModel(ChameleonPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 616c93a46e..5c8f1b3957 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -583,7 +583,7 @@ class CodeGenModel(CodeGenPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 7b8b9547ac..a65d3ee64a 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -910,7 +910,7 @@ class CohereModel(CoherePreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 0d2c4297e0..3f2e7c384d 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1111,7 +1111,7 @@ class DbrxModel(DbrxPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e2ea12b03f..71cd6b6158 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -633,7 +633,7 @@ class GemmaModel(GemmaPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 95ad0d9719..706847650b 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -644,7 +644,7 @@ class GlmModel(GlmPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index ef23b5d208..4e41c80d69 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -792,7 +792,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 7152d72f5b..f512938e75 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -931,7 +931,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 71602f01e7..fba67ae03a 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -667,7 +667,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 4af8f73b5f..00749b7eb0 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -891,7 +891,7 @@ class GPTJModel(GPTJPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 2e045e149d..7e758947b6 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -646,7 +646,7 @@ class GraniteModel(GranitePreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index b2ffbcbc69..e6b9682b5a 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1362,7 +1362,7 @@ class IdeficsModel(IdeficsPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 7b7fd5a90d..a2a86fd4c2 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1126,7 +1126,7 @@ class JetMoeModel(JetMoePreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 5be33c2641..df46e15bce 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -632,7 +632,7 @@ class LlamaModel(LlamaPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 29536d9ad6..15958e772c 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1600,7 +1600,7 @@ class LongT5Stack(LongT5PreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 3e0c4d7a51..6523ab6812 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1076,7 +1076,7 @@ class MllamaPreTrainedModel(PreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 659a84c5fe..e401753601 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1192,7 +1192,7 @@ class MT5Stack(MT5PreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index a0a10bdc6f..75618f1c7e 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -878,7 +878,7 @@ class NemotronModel(NemotronPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 11d3d99f4f..39bfa726de 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -608,7 +608,7 @@ class OlmoModel(OlmoPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 49ae798e7f..89b5f4abe1 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -609,7 +609,7 @@ class Olmo2Model(Olmo2PreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8d3c20b9ac..27712741b7 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -683,7 +683,7 @@ class PersimmonModel(PersimmonPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 477896decd..5aa038d3cc 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -606,7 +606,7 @@ class PhiModel(PhiPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 176dadd5b8..41115a058d 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1587,7 +1587,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 6a64a27e00..bb5366ef76 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -1000,7 +1000,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 36fb1ddf13..5dba7594e7 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -617,7 +617,7 @@ class Qwen2Model(Qwen2PreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 88dc437cdc..7214a36e9a 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -938,7 +938,7 @@ class StableLmModel(StableLmPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index b150b04eea..daeae8f9dc 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1136,7 +1136,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 9012c8db9f..fe6cfbc5c3 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1205,7 +1205,7 @@ class T5Stack(T5PreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 1928ac8a5c..af21f714ef 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1538,7 +1538,7 @@ class UdopStack(UdopPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 985dc5e442..2b007cb2c7 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -849,7 +849,7 @@ class UMT5Stack(UMT5PreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index fb01823a29..21bb2c869b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1375,7 +1375,7 @@ class WhisperDecoder(WhisperPreTrainedModel): output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None