🚨All attention refactor🚨 (#35235)

* refactor LlamaAttention

* minimal changes

* fix llama

* update

* modular gemmas

* modular nits

* modular updates

* nits

* simplify

* gpt2

* more modualr and fixes

* granite

* modular modular modular

* nits

* update

* qwen2 + starcoder2

* mostly gemma2

* Update image_processing_auto.py

* fix

* Update modular_starcoder2.py

* fix

* remove all copied from attentions

* remove gcv

* make fix-copies

* oups

* oups2.0

* fix some modulars + all copied from

* should be good now

* revert unwanted changes

* Update modeling_decision_transformer.py

* finish cleanup

* Update modeling_olmo.py

* consistency

* re-add gradient checkpointing attribute

* fix

* style

* make config necessary

* bis

* bis

* Update modeling_my_new_model2.py

* is_causal attr

* fix

* remove past kv return from decoder layer

* fix

* default rope config

* correctly fix rope config

* fix bias

* fix gpt2 attention output

* fix test

* fix inits

* fix default sdpa

* fix default sdpa implementation

* harmonize classes

* fix mistral

* fix sliding window models

* mixtral

* be more explicit

* style

* fix

* several fixes

* Update modeling_dbrx.py

* fix test

* olmo + phi

* rotary

* syle

* phi

* phi again

* again

* kwargs

* Update test_modeling_common.py

* skip fx tracing tests

* Update modeling_utils.py

* gemma 2

* again

* Update modeling_recurrent_gemma.py

* gemma2

* granite

* style

* starcoder

* Update sdpa_attention.py

* switch args

* Update modeling_mllama.py

* fix

* cache type tests

* gpt2

* Update test_modeling_common.py

* fix

* consistency

* fix shape with encoder

* should be the last one

* tests non model

* most comments

* small oupsi

* be more explicit in modulars

* more explicit modulars

* CIs! it works locally

* add kwargs to _flash_attention_forward

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Arthur
2024-12-18 16:53:39 +01:00
committed by GitHub
parent 75be5a0a5b
commit 2c47618c1a
107 changed files with 5934 additions and 10077 deletions

View File

@@ -453,11 +453,9 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
# Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Falcon
def test_model_rope_scaling(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
head_dim = hidden_size // num_heads
scaling_factor = 10
short_input_length = 10
long_input_length = int(config.max_position_embeddings * 1.5)
@@ -470,11 +468,7 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
position_ids_long = position_ids_long.unsqueeze(0)
# Sanity check original RoPE
original_rope = FalconRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
).to(torch_device)
original_rope = FalconRotaryEmbedding(config).to(torch_device)
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
@@ -482,13 +476,8 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# Sanity check linear RoPE scaling
# New position "x" should match original position with index "x/scaling_factor"
linear_scaling_rope = FalconRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
rope_type="linear",
).to(torch_device)
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
linear_scaling_rope = FalconRotaryEmbedding(config).to(torch_device)
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
@@ -501,13 +490,8 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
# Sanity check Dynamic NTK RoPE scaling
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
# with scaling_factor (or that `inv_freq` decreases)
ntk_scaling_rope = FalconRotaryEmbedding(
head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta,
scaling_factor=scaling_factor,
rope_type="dynamic",
).to(torch_device)
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
ntk_scaling_rope = FalconRotaryEmbedding(config).to(torch_device)
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
torch.testing.assert_close(ntk_cos_short, original_cos_short)