[FlexAttention] Update gemma2 (#34942)
* update tests * now maybe this fixes the previous failing tests! * nit default * Update src/transformers/models/gemma2/modular_gemma2.py Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> * fix-copies --------- Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
@@ -247,9 +247,12 @@ def flex_attention_forward(config, query, key, value, mask, output_attentions=Fa
|
||||
return_lse=output_attentions,
|
||||
)
|
||||
if not output_attentions:
|
||||
return attn_output, None
|
||||
attn_weights = None
|
||||
else:
|
||||
return attn_output[0], attn_output[1]
|
||||
attn_output, attn_weights = attn_output
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
|
||||
@@ -280,6 +283,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
|
||||
is_causal=is_causal,
|
||||
scale=config.scaling,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, None
|
||||
|
||||
|
||||
@@ -362,7 +366,7 @@ class Gemma2Attention(nn.Module):
|
||||
|
||||
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
|
||||
logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
|
||||
attention_type = "eager"
|
||||
attention_type = "flex_attention"
|
||||
else:
|
||||
attention_type = self.config._attn_implementation
|
||||
|
||||
|
||||
@@ -290,9 +290,12 @@ def flex_attention_forward(config, query, key, value, mask, output_attentions=Fa
|
||||
return_lse=output_attentions,
|
||||
)
|
||||
if not output_attentions:
|
||||
return attn_output, None
|
||||
attn_weights = None
|
||||
else:
|
||||
return attn_output[0], attn_output[1]
|
||||
attn_output, attn_weights = attn_output
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
|
||||
@@ -323,6 +326,7 @@ def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
|
||||
is_causal=is_causal,
|
||||
scale=config.scaling,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, None
|
||||
|
||||
|
||||
@@ -405,7 +409,7 @@ class Gemma2Attention(nn.Module):
|
||||
|
||||
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
|
||||
logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`")
|
||||
attention_type = "eager"
|
||||
attention_type = "flex_attention"
|
||||
else:
|
||||
attention_type = self.config._attn_implementation
|
||||
|
||||
|
||||
@@ -385,7 +385,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
|
||||
).to(torch_device)
|
||||
|
||||
assert model.config._attn_implementation == "flex_attention"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user