[Flash Attention 2] Add flash attention 2 for GPT-J (#28295)

* initial implementation of flash attention for gptj

* modify flash attention and overwrite test_flash_attn_2_generate_padding_right

* update flash attention support list

* remove the copy line in the `CodeGenBlock`

* address copy mechanism

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add GPTJ attention classes

* add expected outputs in the gptj test

* Ensure repo consistency with 'make fix-copies'

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
bytebarde
2024-03-13 01:43:00 -06:00
committed by GitHub
parent d522afea13
commit be3fd8a262
4 changed files with 349 additions and 22 deletions

View File

@@ -44,6 +44,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)