Gemma2 and flash-attention (#32188)
* enable flash-attn & static cache * this works, not the prev * fix for sliding window layers * not needed anymore
This commit is contained in:
committed by
GitHub
parent
a3264332cf
commit
7f552e28e0
@@ -16,8 +16,11 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from pytest import mark
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
|
||||
from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
@@ -161,3 +164,28 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
|
||||
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
|
||||
|
||||
@require_read_token
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_model_9b_flash_attn(self):
|
||||
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
|
||||
model_id = "google/gemma-2-9b"
|
||||
EXPECTED_TEXTS = [
|
||||
'<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
|
||||
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
|
||||
] # fmt: skip
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, attn_implementation="flash_attention_2", torch_dtype="float16"
|
||||
).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
|
||||
print(output_text)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
Reference in New Issue
Block a user