Adding Flash Attention 2 Support for GPT2 (#29226)

* First commit to add flash attention 2 for GPT-2

* more improvements

* Make GPT2 pass tests and fixed Decison Transformers copies

* Fixed missing arg

* fix copies

* Added expected speedup

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Added test

* Fixed attn attribute

* Update docs/source/en/model_doc/gpt2.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/model_doc/gpt2.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update Decision transformer attentions

* More updates

* Passing tests

* Fix copies

* Fix copies part 2

* Decision transformer updates

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix copies

* Decision transformer not supporting flash attn

* Addressed comments

* Addressed comments

* Addressed comments

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Eduardo Pacheco
2024-03-28 10:31:24 +01:00
committed by GitHub
parent 3a7e68362b
commit 22d159ddf9
5 changed files with 376 additions and 24 deletions

View File

@@ -19,8 +19,17 @@ import gc
import math
import unittest
import pytest
from transformers import GPT2Config, is_torch_available
from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
from transformers.testing_utils import (
backend_empty_cache,
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -858,3 +867,40 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
"but said in a statement to The Associated Press that"
],
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_left(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).to(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
texts = ["hi", "Hello this is a very long sentence"]
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = GPT2LMHeadModel.from_pretrained(
"gpt2", device_map={"": 0}, attn_implementation="flash_attention_2", torch_dtype=torch.float16
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
expected_output = [
"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>hi, who was born in the city of Kolkata, was a member of the Kolkata",
"Hello this is a very long sentence. I'm sorry. I'm sorry. I'm sorry. I'm sorry. I'm sorry",
]
self.assertListEqual(output_native, output_fa_2)
self.assertListEqual(output_native, expected_output)