Refactor flash attention implementation in transformers (#31446)
* dumb commit * nit * update * something like this * unpack in modeling utils * safe import * oups * update * nits * diff convert gemma * update * start propagating * udpate other modeling code as well * update for sliding window models * nits * more init cleanups * styling * fixup * noice * pass fixup * typo typing_extension -> typing_extensions * torch.nn.functionnal -> torch.nn.functional * add to import structure * unpack * simplify a bit more for this first version * nut * update * update * nit * ease the import of `Unpack` * remove useless `use_sliding_window` * no qua please * protect import? * style * [run-slow] * [run slow] llama,gemma,mistral,mixtral * remove extra kwargs * fix llama * address review comments * apply diff_model_converter to modeling_gemma.py * remove cache_position 1 * remove cache_position 2 * some cleaning * refactor gemma2 as well * apply review comments * rename file to modeling_flash_attention_utils.py * siglip refactor * remove dead code * is the hub down? * still down? * fix siglip * fix gemma2 * fatal: Could not read from remote repository. * fix typo in softcap implem * flacky * Failed: Timeout >120.0s --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
This commit is contained in:
@@ -41,6 +41,8 @@ SPECIAL_CASES_TO_ALLOW = {
|
||||
"expert_layer_offset",
|
||||
"expert_layer_period",
|
||||
],
|
||||
"Qwen2Config": ["use_sliding_window"],
|
||||
"Qwen2MoeConfig": ["use_sliding_window"],
|
||||
"Gemma2Config": ["tie_word_embeddings"],
|
||||
# used to compute the property `self.chunk_length`
|
||||
"EncodecConfig": ["overlap"],
|
||||
|
||||
Reference in New Issue
Block a user