From bd442c6d3aa6298ffc6570574741746439261294 Mon Sep 17 00:00:00 2001 From: pglorio <85982602+pglorio@users.noreply.github.com> Date: Mon, 6 Jan 2025 23:08:45 -1000 Subject: [PATCH] Zamba new attention standard (#35375) * updated zamba to new attention standard * make fixup fixes --- .../models/zamba/modeling_zamba.py | 390 +++++------------- tests/models/zamba/test_modeling_zamba.py | 6 +- 2 files changed, 104 insertions(+), 292 deletions(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index bb2638ee91..761c799bdc 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -20,7 +20,7 @@ """PyTorch Zamba model.""" import math -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -33,18 +33,18 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -113,7 +113,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionDynamicCache(DynamicCache): +class ZambaHybridDynamicCache(DynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -131,9 +131,9 @@ class HybridMambaAttentionDynamicCache(DynamicCache): self.dtype = dtype self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_expand * config.hidden_size - ssm_state_size = config.mamba_d_state - conv_kernel_size = config.mamba_d_conv + self.intermediate_size = config.mamba_expand * config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv self.n_mamba_heads = config.n_mamba_heads self.conv_states = [] self.ssm_states = [] @@ -143,9 +143,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache): self._buffers = {} for i in range(config.num_hidden_layers): self.conv_states += [ - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype) ] - cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) + cache_shape = ( + batch_size, + self.n_mamba_heads, + self.intermediate_size // self.n_mamba_heads, + self.ssm_state_size, + ) self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] if self.layers_block_type[i] == "hybrid": self.transformer_layers.append(i) @@ -194,14 +199,38 @@ class HybridMambaAttentionDynamicCache(DynamicCache): return 0 return self.key_cache[layer_idx].shape[-2] - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.to_legacy_cache def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.") @classmethod - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.from_legacy_cache def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.") + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights class ZambaAttention(nn.Module): @@ -218,277 +247,67 @@ class ZambaAttention(nn.Module): attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) """ - def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None): + def __init__(self, config: ZambaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - self.hidden_size = config.hidden_size self.attention_hidden_size = config.attention_hidden_size - self.num_heads = config.num_attention_heads self.head_dim = config.attention_head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings + self.scaling = (self.head_dim / 2) ** -0.5 self.is_causal = True self.attention_dropout = config.attention_dropout - if (self.head_dim * self.num_heads) != self.attention_hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.attention_hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.q_proj = nn.Linear(config.attention_hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, hidden_states: torch.Tensor, layer_idx: int, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[ZambaHybridDynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if past_key_value is not None: key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size) - - attn_output = attn_output - attn_output = self.o_proj(attn_output) - attn_output = attn_output - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: -# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward -# dropped use_sliding_windows from the arguments of self._flash_attention_forward -class ZambaFlashAttention2(ZambaAttention): - """ - Zamba flash attention module. This module inherits from `ZambaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - layer_idx: int, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - softmax_scale = 1 / math.sqrt(self.head_dim / 2) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=softmax_scale, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: -# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention -class ZambaSdpaAttention(ZambaAttention): - """ - Zamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `ZambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - def forward( - self, - hidden_states: torch.Tensor, - layer_idx: int, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "ZambaModel is using ZambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - softmax_scale = 1 / math.sqrt(self.head_dim / 2) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - scale=softmax_scale, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -ZAMBA_ATTENTION_CLASSES = { - "eager": ZambaAttention, - "flash_attention_2": ZambaFlashAttention2, - "sdpa": ZambaSdpaAttention, -} + return attn_output, attn_weights class ZambaMambaMixer(nn.Module): @@ -568,7 +387,7 @@ class ZambaMambaMixer(nn.Module): ) def cuda_kernels_forward( - self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None + self, hidden_states: torch.Tensor, cache_params: ZambaHybridDynamicCache = None, attention_mask=None ): batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 @@ -664,7 +483,7 @@ class ZambaMambaMixer(nn.Module): contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states - def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None): + def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated linear projection @@ -675,7 +494,7 @@ class ZambaMambaMixer(nn.Module): gate = gate.squeeze(2) gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1) - use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache) + use_cache = isinstance(cache_params, ZambaHybridDynamicCache) # 2. Convolution sequence transformation if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: if self.training: @@ -757,7 +576,7 @@ class ZambaMambaMixer(nn.Module): ) return contextualized_states - def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None): + def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None): if self.use_fast_kernels: if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type: raise ValueError( @@ -789,7 +608,7 @@ class ZambaMLP(nn.Module): class ZambaAttentionDecoderLayer(nn.Module): def __init__(self, config: ZambaConfig, layer_idx: Optional[int] = None): super().__init__() - self.self_attn = ZAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = ZambaAttention(config, layer_idx) self.feed_forward = ZambaMLP(config) self.input_layernorm = ZambaRMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) @@ -802,11 +621,11 @@ class ZambaAttentionDecoderLayer(nn.Module): layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -815,9 +634,11 @@ class ZambaAttentionDecoderLayer(nn.Module): This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The concatenated tensor is then used as input of the pre-attention RMSNorm (see fig. 2 in https://arxiv.org/pdf/2405.16712). + layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + position_ids (`torch.LongTensor`, *optional*): token positions of shape `(batch, seq_len)`. Used for positional encodings. + past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -829,7 +650,7 @@ class ZambaAttentionDecoderLayer(nn.Module): """ hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, @@ -849,9 +670,6 @@ class ZambaAttentionDecoderLayer(nn.Module): if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -870,7 +688,7 @@ class ZambaMambaDecoderLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -881,7 +699,7 @@ class ZambaMambaDecoderLayer(nn.Module): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -923,7 +741,7 @@ class ZambaMambaDecoderLayer(nn.Module): return outputs -class HybridLayer(nn.Module): +class ZambaHybridLayer(nn.Module): def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer): super().__init__() self.shared_transf = shared_transf @@ -938,7 +756,7 @@ class HybridLayer(nn.Module): attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -951,7 +769,7 @@ class HybridLayer(nn.Module): layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1027,7 +845,7 @@ class ZambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False _supports_sdpa = False - _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + _supports_cache_class = True # Note: only supports ZambaHybridDynamicCache _is_stateful = True def _init_weights(self, module): @@ -1121,14 +939,14 @@ ZAMBA_INPUTS_DOCSTRING = r""" config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + past_key_values (`ZambaHybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A ZambaHybridDynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and `(batch_size, d_inner, d_state)` respectively. - See the `HybridMambaAttentionDynamicCache` class for more details. + See the `ZambaHybridDynamicCache` class for more details. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1202,7 +1020,7 @@ class ZambaModel(ZambaPreTrainedModel): "shared_transf.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers))) + layers.append(ZambaHybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) self.layers = nn.ModuleList(layers) @@ -1226,7 +1044,7 @@ class ZambaModel(ZambaPreTrainedModel): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[ZambaHybridDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1263,7 +1081,7 @@ class ZambaModel(ZambaPreTrainedModel): if use_cache and past_key_values is None: logger.warning_once( - "Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "Zamba requires an initialized `ZambaHybridDynamicCache` to return a cache. None was " "provided, so no cache will be returned." ) @@ -1324,17 +1142,13 @@ class ZambaModel(ZambaPreTrainedModel): if past_key_values and not past_key_values.has_previous_state: past_key_values.has_previous_state = True - next_cache = None if not use_cache else past_key_values - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask def _update_causal_mask(self, attention_mask, input_tensor, cache_position): @@ -1410,7 +1224,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[ZambaHybridDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1504,7 +1318,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): use_cache=True, **kwargs, ): - # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + # Overwitten -- has a unique cache type, `ZambaHybridDynamicCache` empty_past_kv = past_key_values is None @@ -1518,7 +1332,7 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - past_key_values = HybridMambaAttentionDynamicCache( + past_key_values = ZambaHybridDynamicCache( self.config, input_ids.shape[0], dtype=self.dtype, device=self.device ) diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index a6dd516f98..ee47f98a1f 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -46,7 +46,7 @@ if is_torch_available(): ZambaModel, ) from transformers.models.zamba.modeling_zamba import ( - HybridMambaAttentionDynamicCache, + ZambaHybridDynamicCache, ) @@ -215,9 +215,7 @@ class ZambaModelTester: # first forward pass # Attention: Zamba needs the cache to be initialized to return a cache! - past_key_values = HybridMambaAttentionDynamicCache( - config, input_ids.shape[0], model.dtype, device=model.device - ) + past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask,