🚨All attention refactor🚨 (#35235)
* refactor LlamaAttention * minimal changes * fix llama * update * modular gemmas * modular nits * modular updates * nits * simplify * gpt2 * more modualr and fixes * granite * modular modular modular * nits * update * qwen2 + starcoder2 * mostly gemma2 * Update image_processing_auto.py * fix * Update modular_starcoder2.py * fix * remove all copied from attentions * remove gcv * make fix-copies * oups * oups2.0 * fix some modulars + all copied from * should be good now * revert unwanted changes * Update modeling_decision_transformer.py * finish cleanup * Update modeling_olmo.py * consistency * re-add gradient checkpointing attribute * fix * style * make config necessary * bis * bis * Update modeling_my_new_model2.py * is_causal attr * fix * remove past kv return from decoder layer * fix * default rope config * correctly fix rope config * fix bias * fix gpt2 attention output * fix test * fix inits * fix default sdpa * fix default sdpa implementation * harmonize classes * fix mistral * fix sliding window models * mixtral * be more explicit * style * fix * several fixes * Update modeling_dbrx.py * fix test * olmo + phi * rotary * syle * phi * phi again * again * kwargs * Update test_modeling_common.py * skip fx tracing tests * Update modeling_utils.py * gemma 2 * again * Update modeling_recurrent_gemma.py * gemma2 * granite * style * starcoder * Update sdpa_attention.py * switch args * Update modeling_mllama.py * fix * cache type tests * gpt2 * Update test_modeling_common.py * fix * consistency * fix shape with encoder * should be the last one * tests non model * most comments * small oupsi * be more explicit in modulars * more explicit modulars * CIs! it works locally * add kwargs to _flash_attention_forward --------- Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
@@ -10,7 +10,7 @@ from typing import ClassVar, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, StaticCache
|
||||
from ...cache_utils import Cache, HybridCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
@@ -253,7 +253,14 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
return self.language_model.tie_weights()
|
||||
|
||||
def _update_causal_mask(
|
||||
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
|
||||
self,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
input_ids=None,
|
||||
inputs_embeds=None,
|
||||
is_training: bool = False,
|
||||
):
|
||||
if self.config.text_config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
@@ -261,11 +268,13 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
return None
|
||||
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
dtype = inputs_embeds.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = inputs_embeds.shape[1]
|
||||
min_dtype = torch.finfo(self.dtype).min
|
||||
inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
|
||||
sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
elif isinstance(past_key_values, HybridCache):
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
@@ -278,7 +287,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
return attention_mask
|
||||
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
|
||||
)
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
@@ -288,7 +297,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
causal_mask[:, :sequence_length] = 0.0
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
@@ -317,7 +326,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = image_features / (self.config.hidden_size**0.5)
|
||||
image_features = image_features / (self.config.text_config.hidden_size**0.5)
|
||||
return image_features
|
||||
|
||||
@add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
|
||||
@@ -414,6 +423,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=None,
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
@@ -433,12 +443,16 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
# position_ids in NewTaskModel are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
|
||||
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
||||
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
||||
if cache_position[0] == 0:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
return model_inputs
|
||||
|
||||
def resize_token_embeddings(
|
||||
|
||||
Reference in New Issue
Block a user