Refactor flash attention implementation in transformers (#31446)

* dumb commit

* nit

* update

* something like this

* unpack in modeling utils

* safe import

* oups

* update

* nits

* diff convert gemma

* update

* start propagating

* udpate other modeling code as well

* update for sliding window models

* nits

* more init cleanups

* styling

* fixup

* noice

* pass fixup

* typo typing_extension -> typing_extensions

* torch.nn.functionnal -> torch.nn.functional

* add to import structure

* unpack

* simplify a bit more for this first version

* nut

* update

* update

* nit

* ease the import of `Unpack`

* remove useless `use_sliding_window`

* no qua please

* protect import?

* style

* [run-slow]

* [run slow] llama,gemma,mistral,mixtral

* remove extra kwargs

* fix llama

* address review comments

* apply diff_model_converter to modeling_gemma.py

* remove cache_position 1

* remove cache_position 2

* some cleaning

* refactor gemma2 as well

* apply review comments

* rename file to modeling_flash_attention_utils.py

* siglip refactor

* remove dead code

* is the hub down?

* still down?

* fix siglip

* fix gemma2

* fatal: Could not read from remote repository.

* fix typo in softcap implem

* flacky

* Failed: Timeout >120.0s

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
This commit is contained in:
Arthur
2024-07-11 14:37:31 +02:00
committed by GitHub
parent ad4ef3a290
commit e314395277
49 changed files with 792 additions and 5365 deletions

View File

@@ -41,6 +41,8 @@ SPECIAL_CASES_TO_ALLOW = {
"expert_layer_offset",
"expert_layer_period",
],
"Qwen2Config": ["use_sliding_window"],
"Qwen2MoeConfig": ["use_sliding_window"],
"Gemma2Config": ["tie_word_embeddings"],
# used to compute the property `self.chunk_length`
"EncodecConfig": ["overlap"],