Adding Flash Attention 2 Support for GPT2 (#29226)
* First commit to add flash attention 2 for GPT-2 * more improvements * Make GPT2 pass tests and fixed Decision Transformers copies * Fixed missing arg * fix copies * Added expected speedup * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Added test * Fixed attn attribute * Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/model_doc/gpt2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update Decision transformer attentions * More updates * Passing tests * Fix copies * Fix copies part 2 * Decision transformer updates * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix copies * Decision transformer not supporting flash attn * Addressed comments * Addressed comments * Addressed comments --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -19,8 +19,17 @@ import gc
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import GPT2Config, is_torch_available
|
||||
from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -858,3 +867,40 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
"but said in a statement to The Associated Press that"
|
||||
],
|
||||
)
|
||||
|
||||
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_left(self):
    """
    Overwriting the common test as the test is flaky on tiny models.

    Greedy-decodes one left-padded batch twice with the "gpt2" checkpoint in float16 on device 0:
    first with the model loaded without an explicit `attn_implementation`, then with
    `attn_implementation="flash_attention_2"`. The decoded texts must match each other and the
    pinned expected strings.
    """
    checkpoint = "gpt2"
    # Greedy decoding so both runs are directly comparable.
    generation_kwargs = {"max_new_tokens": 20, "do_sample": False}

    lm = GPT2LMHeadModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(0)

    tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
    # Pad on the left with the EOS token; the shorter prompt therefore shows up
    # prefixed by "<|endoftext|>" markers in the decoded text below.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    prompts = ["hi", "Hello this is a very long sentence"]
    batch = tokenizer(prompts, return_tensors="pt", padding=True).to(0)

    # Reference run: model loaded without an explicit attention implementation.
    native_ids = lm.generate(**batch, **generation_kwargs)
    output_native = tokenizer.batch_decode(native_ids)

    # Second run: same checkpoint and inputs, Flash Attention 2 requested at load time.
    lm = GPT2LMHeadModel.from_pretrained(
        checkpoint,
        device_map={"": 0},
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
    )
    fa_2_ids = lm.generate(**batch, **generation_kwargs)
    output_fa_2 = tokenizer.batch_decode(fa_2_ids)

    expected_output = [
        "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>hi, who was born in the city of Kolkata, was a member of the Kolkata",
        "Hello this is a very long sentence. I'm sorry. I'm sorry. I'm sorry. I'm sorry. I'm sorry",
    ]

    self.assertListEqual(output_native, output_fa_2)
    self.assertListEqual(output_native, expected_output)
|
||||
|
||||
Reference in New Issue
Block a user