[Flash Attention 2] Add flash attention 2 for GPT-J (#28295)
* initial implementation of flash attention for gptj * modify flash attention and overwrite test_flash_attn_2_generate_padding_right * update flash attention support list * remove the copy line in the `CodeGenBlock` * address copy mechanism * Update src/transformers/models/gptj/modeling_gptj.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Add GPTJ attention classes * add expected outputs in the gptj test * Ensure repo consistency with 'make fix-copies' --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -17,8 +17,18 @@
|
||||
import datetime
|
||||
import unittest
|
||||
|
||||
from transformers import GPTJConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, tooslow, torch_device
|
||||
import pytest
|
||||
|
||||
from transformers import BitsAndBytesConfig, GPTJConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
tooslow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -518,6 +528,44 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
||||
model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@pytest.mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
"""
|
||||
Overwritting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
|
||||
|
||||
texts = ["hi", "Hello this is a very long sentence"]
|
||||
expected_outputs = [
|
||||
"hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. I have a question about the",
|
||||
"Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it",
|
||||
]
|
||||
|
||||
tokenizer.padding_side = "right"
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
|
||||
model = GPTJForCausalLM.from_pretrained(
|
||||
"EleutherAI/gpt-j-6b",
|
||||
device_map={"": 0},
|
||||
attn_implementation="flash_attention_2",
|
||||
revision="float16",
|
||||
torch_dtype=torch.float16,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
|
||||
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_fa_2 = tokenizer.batch_decode(output_fa_2)
|
||||
|
||||
self.assertListEqual(expected_outputs, output_fa_2)
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPTJModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user