🚨All attention refactor🚨 (#35235)

* refactor LlamaAttention

* minimal changes

* fix llama

* update

* modular gemmas

* modular nits

* modular updates

* nits

* simplify

* gpt2

* more modualr and fixes

* granite

* modular modular modular

* nits

* update

* qwen2 + starcoder2

* mostly gemma2

* Update image_processing_auto.py

* fix

* Update modular_starcoder2.py

* fix

* remove all copied from attentions

* remove gcv

* make fix-copies

* oups

* oups2.0

* fix some modulars + all copied from

* should be good now

* revert unwanted changes

* Update modeling_decision_transformer.py

* finish cleanup

* Update modeling_olmo.py

* consistency

* re-add gradient checkpointing attribute

* fix

* style

* make config necessary

* bis

* bis

* Update modeling_my_new_model2.py

* is_causal attr

* fix

* remove past kv return from decoder layer

* fix

* default rope config

* correctly fix rope config

* fix bias

* fix gpt2 attention output

* fix test

* fix inits

* fix default sdpa

* fix default sdpa implementation

* harmonize classes

* fix mistral

* fix sliding window models

* mixtral

* be more explicit

* style

* fix

* several fixes

* Update modeling_dbrx.py

* fix test

* olmo + phi

* rotary

* syle

* phi

* phi again

* again

* kwargs

* Update test_modeling_common.py

* skip fx tracing tests

* Update modeling_utils.py

* gemma 2

* again

* Update modeling_recurrent_gemma.py

* gemma2

* granite

* style

* starcoder

* Update sdpa_attention.py

* switch args

* Update modeling_mllama.py

* fix

* cache type tests

* gpt2

* Update test_modeling_common.py

* fix

* consistency

* fix shape with encoder

* should be the last one

* tests non model

* most comments

* small oupsi

* be more explicit in modulars

* more explicit modulars

* CIs! it works locally

* add kwargs to _flash_attention_forward

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Arthur
2024-12-18 16:53:39 +01:00
committed by GitHub
parent 75be5a0a5b
commit 2c47618c1a
107 changed files with 5934 additions and 10077 deletions

View File

@@ -563,32 +563,17 @@ class ModelUtilsTest(TestCasePlus):
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
mistral_attention_classes = {
"eager": "MistralAttention",
"sdpa": "MistralSdpaAttention",
"flash_attention_2": "MistralFlashAttention2",
}
for requested_attn_implementation in attn_implementation_available:
model = AutoModelForCausalLM.from_pretrained(
TINY_MISTRAL, attn_implementation=requested_attn_implementation
)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
config = AutoConfig.from_pretrained(TINY_MISTRAL)
model = AutoModelForCausalLM.from_pretrained(
TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation
)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
def test_model_from_config_attn_implementation(self):
# test that the model can be instantiated with attn_implementation of either
@@ -602,11 +587,6 @@ class ModelUtilsTest(TestCasePlus):
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
mistral_attention_classes = {
"eager": "MistralAttention",
"sdpa": "MistralSdpaAttention",
"flash_attention_2": "MistralFlashAttention2",
}
for requested_attn_implementation in attn_implementation_available:
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
# Ensure the config was set correctly
@@ -614,11 +594,6 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
model = AutoModelForCausalLM.from_config(config)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
config = AutoConfig.from_pretrained(TINY_MISTRAL)
# When the config is not set, the default is "eager"
@@ -626,11 +601,6 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(config._attn_implementation_internal, None)
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
# Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
@@ -638,11 +608,6 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
def test_torch_dtype_byte_sizes(self):
torch_dtypes_and_bytes = [