[Flash Attention 2] Add flash attention 2 for GPT-J (#28295)

* initial implementation of flash attention for gptj

* modify flash attention and overwrite test_flash_attn_2_generate_padding_right

* update flash attention support list

* remove the copy line in the `CodeGenBlock`

* address copy mechanism

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add GPTJ attention classes

* add expected outputs in the gptj test

* Ensure repo consistency with 'make fix-copies'

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
bytebarde
2024-03-13 01:43:00 -06:00
committed by GitHub
parent d522afea13
commit be3fd8a262
4 changed files with 349 additions and 22 deletions

View File

@@ -17,8 +17,18 @@
import datetime
import unittest
from transformers import GPTJConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, tooslow, torch_device
import pytest
from transformers import BitsAndBytesConfig, GPTJConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
tooslow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -518,6 +528,44 @@ class GPTJModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
model = GPTJModel.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16)
self.assertIsNotNone(model)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
texts = ["hi", "Hello this is a very long sentence"]
expected_outputs = [
"hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. I have a question about the",
"Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it",
]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = GPTJForCausalLM.from_pretrained(
"EleutherAI/gpt-j-6b",
device_map={"": 0},
attn_implementation="flash_attention_2",
revision="float16",
torch_dtype=torch.float16,
quantization_config=quantization_config,
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(expected_outputs, output_fa_2)
@require_torch
class GPTJModelLanguageGenerationTest(unittest.TestCase):