[FlexAttention] Update gemma2 (#34942)

* update tests

* now maybe this fixes the previous fialing tests!

* nit default

* Update src/transformers/models/gemma2/modular_gemma2.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

* fix-copies

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
Arthur
2024-11-27 11:50:48 +01:00
committed by GitHub
parent 6c3f168b36
commit 4c1388f48e
3 changed files with 15 additions and 7 deletions

View File

@@ -385,7 +385,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)
assert model.config._attn_implementation == "flex_attention"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)