[FlexAttention] Update gemma2 (#34942)
* update tests * now maybe this fixes the previous failing tests! * nit default * Update src/transformers/models/gemma2/modular_gemma2.py Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> * fix-copies --------- Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
This commit is contained in:
@@ -385,7 +385,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
|
||||
).to(torch_device)
|
||||
|
||||
assert model.config._attn_implementation == "flex_attention"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user