make LlamaModel._update_causal_mask torch compilable (#35187)

* make LlamaModel._update_causal_mask torch compilable

* chore: lint (make fix-copies)

* fix-copies

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Wing Lian
2024-12-23 07:10:00 -05:00
committed by GitHub
parent 401aa39d7b
commit 5e7aedebeb
33 changed files with 33 additions and 33 deletions

View File

@@ -1012,7 +1012,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -740,7 +740,7 @@ class BloomModel(BloomPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1385,7 +1385,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -583,7 +583,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -910,7 +910,7 @@ class CohereModel(CoherePreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1111,7 +1111,7 @@ class DbrxModel(DbrxPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -633,7 +633,7 @@ class GemmaModel(GemmaPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -644,7 +644,7 @@ class GlmModel(GlmPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -792,7 +792,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -931,7 +931,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -667,7 +667,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -891,7 +891,7 @@ class GPTJModel(GPTJPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -646,7 +646,7 @@ class GraniteModel(GranitePreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1362,7 +1362,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1126,7 +1126,7 @@ class JetMoeModel(JetMoePreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -632,7 +632,7 @@ class LlamaModel(LlamaPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1600,7 +1600,7 @@ class LongT5Stack(LongT5PreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1076,7 +1076,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1192,7 +1192,7 @@ class MT5Stack(MT5PreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -878,7 +878,7 @@ class NemotronModel(NemotronPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -608,7 +608,7 @@ class OlmoModel(OlmoPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -609,7 +609,7 @@ class Olmo2Model(Olmo2PreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -683,7 +683,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -606,7 +606,7 @@ class PhiModel(PhiPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1587,7 +1587,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1000,7 +1000,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -617,7 +617,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -938,7 +938,7 @@ class StableLmModel(StableLmPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1136,7 +1136,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1205,7 +1205,7 @@ class T5Stack(T5PreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1538,7 +1538,7 @@ class UdopStack(UdopPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -849,7 +849,7 @@ class UMT5Stack(UMT5PreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

View File

@@ -1375,7 +1375,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None