🚨All attention refactor🚨 (#35235)

* refactor LlamaAttention

* minimal changes

* fix llama

* update

* modular gemmas

* modular nits

* modular updates

* nits

* simplify

* gpt2

* more modular and fixes

* granite

* modular modular modular

* nits

* update

* qwen2 + starcoder2

* mostly gemma2

* Update image_processing_auto.py

* fix

* Update modular_starcoder2.py

* fix

* remove all copied from attentions

* remove gcv

* make fix-copies

* oups

* oups2.0

* fix some modulars + all copied from

* should be good now

* revert unwanted changes

* Update modeling_decision_transformer.py

* finish cleanup

* Update modeling_olmo.py

* consistency

* re-add gradient checkpointing attribute

* fix

* style

* make config necessary

* bis

* bis

* Update modeling_my_new_model2.py

* is_causal attr

* fix

* remove past kv return from decoder layer

* fix

* default rope config

* correctly fix rope config

* fix bias

* fix gpt2 attention output

* fix test

* fix inits

* fix default sdpa

* fix default sdpa implementation

* harmonize classes

* fix mistral

* fix sliding window models

* mixtral

* be more explicit

* style

* fix

* several fixes

* Update modeling_dbrx.py

* fix test

* olmo + phi

* rotary

* style

* phi

* phi again

* again

* kwargs

* Update test_modeling_common.py

* skip fx tracing tests

* Update modeling_utils.py

* gemma 2

* again

* Update modeling_recurrent_gemma.py

* gemma2

* granite

* style

* starcoder

* Update sdpa_attention.py

* switch args

* Update modeling_mllama.py

* fix

* cache type tests

* gpt2

* Update test_modeling_common.py

* fix

* consistency

* fix shape with encoder

* should be the last one

* tests non model

* most comments

* small oupsi

* be more explicit in modulars

* more explicit modulars

* CIs! it works locally

* add kwargs to _flash_attention_forward

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
Arthur
2024-12-18 16:53:39 +01:00
committed by GitHub
parent 75be5a0a5b
commit 2c47618c1a
107 changed files with 5934 additions and 10077 deletions

View File

@@ -10,7 +10,7 @@ from typing import ClassVar, List, Optional, Tuple, Union
import torch
from torch import nn
from ...cache_utils import Cache, StaticCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -253,7 +253,14 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
return self.language_model.tie_weights()
def _update_causal_mask(
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
self,
attention_mask,
token_type_ids,
past_key_values,
cache_position,
input_ids=None,
inputs_embeds=None,
is_training: bool = False,
):
if self.config.text_config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
@@ -261,11 +268,13 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
return None
using_static_cache = isinstance(past_key_values, StaticCache)
dtype = inputs_embeds.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = inputs_embeds.shape[1]
min_dtype = torch.finfo(self.dtype).min
inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
elif isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
@@ -278,7 +287,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
return attention_mask
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
@@ -288,7 +297,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
causal_mask[:, :sequence_length] = 0.0
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
@@ -317,7 +326,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
image_outputs = self.vision_tower(pixel_values)
selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature)
image_features = image_features / (self.config.hidden_size**0.5)
image_features = image_features / (self.config.text_config.hidden_size**0.5)
return image_features
@add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
@@ -414,6 +423,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
token_type_ids=None,
use_cache=True,
num_logits_to_keep=None,
labels=None,
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -433,12 +443,16 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
# position_ids in NewTaskModel are 1-indexed
if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
model_inputs["pixel_values"] = pixel_values
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
def resize_token_embeddings(