🚨All attention refactor🚨 (#35235)

* refactor LlamaAttention

* minimal changes

* fix llama

* update

* modular gemmas

* modular nits

* modular updates

* nits

* simplify

* gpt2

* more modualr and fixes

* granite

* modular modular modular

* nits

* update

* qwen2 + starcoder2

* mostly gemma2

* Update image_processing_auto.py

* fix

* Update modular_starcoder2.py

* fix

* remove all copied from attentions

* remove gcv

* make fix-copies

* oups

* oups2.0

* fix some modulars + all copied from

* should be good now

* revert unwanted changes

* Update modeling_decision_transformer.py

* finish cleanup

* Update modeling_olmo.py

* consistency

* re-add gradient checkpointing attribute

* fix

* style

* make config necessary

* bis

* bis

* Update modeling_my_new_model2.py

* is_causal attr

* fix

* remove past kv return from decoder layer

* fix

* default rope config

* correctly fix rope config

* fix bias

* fix gpt2 attention output

* fix test

* fix inits

* fix default sdpa

* fix default sdpa implementation

* harmonize classes

* fix mistral

* fix sliding window models

* mixtral

* be more explicit

* style

* fix

* several fixes

* Update modeling_dbrx.py

* fix test

* olmo + phi

* rotary

* syle

* phi

* phi again

* again

* kwargs

* Update test_modeling_common.py

* skip fx tracing tests

* Update modeling_utils.py

* gemma 2

* again

* Update modeling_recurrent_gemma.py

* gemma2

* granite

* style

* starcoder

* Update sdpa_attention.py

* switch args

* Update modeling_mllama.py

* fix

* cache type tests

* gpt2

* Update test_modeling_common.py

* fix

* consistency

* fix shape with encoder

* should be the last one

* tests non model

* most comments

* small oupsi

* be more explicit in modulars

* more explicit modulars

* CIs! it works locally

* add kwargs to _flash_attention_forward

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Arthur
2024-12-18 16:53:39 +01:00
committed by GitHub
parent 75be5a0a5b
commit 2c47618c1a
107 changed files with 5934 additions and 10077 deletions

View File

@@ -484,7 +484,7 @@ class TFModelTesterMixin:
return new_tf_outputs, new_pt_outputs
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None):
"""Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way.
Args:
@@ -495,6 +495,7 @@ class TFModelTesterMixin:
attributes (`Tuple[str]`): The names of the output's element if the output is a tuple/list with each element
being a named field in the output.
"""
from transformers.cache_utils import DynamicCache
self.assertEqual(type(name), str)
if attributes is not None:
@@ -540,6 +541,8 @@ class TFModelTesterMixin:
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
if isinstance(pt_output, DynamicCache):
pt_output = pt_output.to_legacy_cache()
self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(tf_outputs, tf.Tensor):