[FlexAttention] Update gemma2 (#34942)

* update tests

* now maybe this fixes the previous failing tests!

* nit default

* Update src/transformers/models/gemma2/modular_gemma2.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

* fix-copies

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
Arthur
2024-11-27 11:50:48 +01:00
committed by GitHub
parent 6c3f168b36
commit 4c1388f48e
3 changed files with 15 additions and 7 deletions

View File

@@ -247,9 +247,12 @@ def flex_attention_forward(config, query, key, value, mask, output_attentions=Fa
return_lse=output_attentions,
)
if not output_attentions:
return attn_output, None
attn_weights = None
else:
return attn_output[0], attn_output[1]
attn_output, attn_weights = attn_output
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
@@ -280,6 +283,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
is_causal=is_causal,
scale=config.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
@@ -362,7 +366,7 @@ class Gemma2Attention(nn.Module):
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
attention_type = "eager"
attention_type = "flex_attention"
else:
attention_type = self.config._attn_implementation

View File

@@ -290,9 +290,12 @@ def flex_attention_forward(config, query, key, value, mask, output_attentions=Fa
return_lse=output_attentions,
)
if not output_attentions:
return attn_output, None
attn_weights = None
else:
return attn_output[0], attn_output[1]
attn_output, attn_weights = attn_output
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
@@ -323,6 +326,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
is_causal=is_causal,
scale=config.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
@@ -405,7 +409,7 @@ class Gemma2Attention(nn.Module):
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
attention_type = "eager"
attention_type = "flex_attention"
else:
attention_type = self.config._attn_implementation

View File

@@ -385,7 +385,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)
assert model.config._attn_implementation == "flex_attention"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)