🚨All attention refactor🚨 (#35235)

* refactor LlamaAttention

* minimal changes

* fix llama

* update

* modular gemmas

* modular nits

* modular updates

* nits

* simplify

* gpt2

* more modualr and fixes

* granite

* modular modular modular

* nits

* update

* qwen2 + starcoder2

* mostly gemma2

* Update image_processing_auto.py

* fix

* Update modular_starcoder2.py

* fix

* remove all copied from attentions

* remove gcv

* make fix-copies

* oups

* oups2.0

* fix some modulars + all copied from

* should be good now

* revert unwanted changes

* Update modeling_decision_transformer.py

* finish cleanup

* Update modeling_olmo.py

* consistency

* re-add gradient checkpointing attribute

* fix

* style

* make config necessary

* bis

* bis

* Update modeling_my_new_model2.py

* is_causal attr

* fix

* remove past kv return from decoder layer

* fix

* default rope config

* correctly fix rope config

* fix bias

* fix gpt2 attention output

* fix test

* fix inits

* fix default sdpa

* fix default sdpa implementation

* harmonize classes

* fix mistral

* fix sliding window models

* mixtral

* be more explicit

* style

* fix

* several fixes

* Update modeling_dbrx.py

* fix test

* olmo + phi

* rotary

* syle

* phi

* phi again

* again

* kwargs

* Update test_modeling_common.py

* skip fx tracing tests

* Update modeling_utils.py

* gemma 2

* again

* Update modeling_recurrent_gemma.py

* gemma2

* granite

* style

* starcoder

* Update sdpa_attention.py

* switch args

* Update modeling_mllama.py

* fix

* cache type tests

* gpt2

* Update test_modeling_common.py

* fix

* consistency

* fix shape with encoder

* should be the last one

* tests non model

* most comments

* small oupsi

* be more explicit in modulars

* more explicit modulars

* CIs! it works locally

* add kwargs to _flash_attention_forward

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Arthur
2024-12-18 16:53:39 +01:00
committed by GitHub
parent 75be5a0a5b
commit 2c47618c1a
107 changed files with 5934 additions and 10077 deletions

View File

@@ -119,6 +119,7 @@ if is_torch_available():
from torch import nn
from transformers import MODEL_MAPPING, AdaptiveEmbedding
from transformers.cache_utils import DynamicCache
from transformers.modeling_utils import load_state_dict, no_init_weights
from transformers.pytorch_utils import id_tensor_storage
@@ -1285,6 +1286,11 @@ class ModelTesterMixin:
)
for i in range(model.config.num_hidden_layers)
)
empty_pkv = (
DynamicCache.from_legacy_cache(empty_pkv)
if model_class._supports_cache_class
else empty_pkv
)
cache_length = 9
cache_shape = (batch_size, num_heads, cache_length, head_dim)
@@ -1295,6 +1301,11 @@ class ModelTesterMixin:
)
for i in range(model.config.num_hidden_layers)
)
non_empty_pkv = (
DynamicCache.from_legacy_cache(non_empty_pkv)
if model_class._supports_cache_class
else non_empty_pkv
)
inps = copy.deepcopy(inputs_to_test[0])
@@ -2471,7 +2482,7 @@ class ModelTesterMixin:
return new_tf_outputs, new_pt_outputs
# Copied from tests.test_modeling_tf_common.TFModelTesterMixin.check_pt_tf_outputs
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
"""Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
Args:
@@ -2527,6 +2538,8 @@ class ModelTesterMixin:
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
if isinstance(pt_output, DynamicCache):
pt_output = pt_output.to_legacy_cache()
self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(tf_outputs, tf.Tensor):
@@ -2702,7 +2715,7 @@ class ModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
@@ -2712,7 +2725,6 @@ class ModelTesterMixin:
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
self.assertEqual(type(name), str)
if attributes is not None:
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
@@ -2757,6 +2769,8 @@ class ModelTesterMixin:
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
if isinstance(pt_output, DynamicCache):
pt_output = pt_output.to_legacy_cache()
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(fx_outputs, jnp.ndarray):
@@ -3881,15 +3895,6 @@ class ModelTesterMixin:
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
@@ -3942,15 +3947,6 @@ class ModelTesterMixin:
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and any(module_attn == "sdpa" for module_attn in [text_attn, vision_attn]):
raise ValueError("The SDPA model should have SDPA attention layers")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
def test_eager_matches_sdpa_inference(self, torch_dtype: str):