From 3a4ae6eace078c6d3c0f064b246cf9bde8978812 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 9 Jan 2025 17:54:57 +0100 Subject: [PATCH] Refactor/fix Cohere2 (#35594) * refactor/fix cohere2 * add kwargs * tests * remove func and import it --- .../models/cohere2/modeling_cohere2.py | 266 ++++++------------ .../models/cohere2/modular_cohere2.py | 248 ++++------------ .../models/gemma2/modeling_gemma2.py | 2 + .../models/gemma2/modular_gemma2.py | 2 + tests/models/cohere2/test_modeling_cohere2.py | 5 - 5 files changed, 146 insertions(+), 377 deletions(-) diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 9c8a8891e1..8fa99c06a0 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -19,8 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -31,23 +30,18 @@ from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, logging, replace_return_docstrings, ) from .configuration_cohere2 import Cohere2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Cohere2Config" @@ -139,6 +133,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + def rotate_half(x): # Split and rotate. Note that this function is different from e.g. Llama. x1 = x[..., ::2] @@ -177,120 +197,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) -def eager_attention_forward( - config: Cohere2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, torch.Tensor]: - key_states = repeat_kv(key, config.num_key_value_groups) - value_states = repeat_kv(value, config.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(config.head_dim) - - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -def flash_attention_forward( - config: Cohere2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - target_dtype: torch.dtype = torch.float16, - **_kwargs, -) -> Tuple[torch.Tensor, None]: - if mask is not None: - seq_len = mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - value_states = value.transpose(1, 2) - - dropout_rate = config.attention_dropout if config.training else 0.0 - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - mask, - seq_len, - dropout=dropout_rate, - is_causal=config.is_causal, - sliding_window=config.sliding_window, - use_top_left_mask=config._flash_attn_uses_top_left_mask, - ) - - return attn_output, None - - -def sdpa_attention_forward( - config: Cohere2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, None]: - key = repeat_kv(key, config.num_key_value_groups) - value = repeat_kv(value, config.num_key_value_groups) - - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query.device.type == "cuda" and causal_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and query.shape[1] > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=config.attention_dropout if config.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, None - - -COHERE2_ATTENTION_FUNCTION = { - "flash_attention_2": flash_attention_forward, - "eager": eager_attention_forward, - "sdpa": sdpa_attention_forward, -} - - class Cohere2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -298,34 +204,24 @@ class Cohere2Attention(nn.Module): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) self.sliding_window = ( config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None ) @@ -334,25 +230,19 @@ class Cohere2Attention(nn.Module): self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - if self.sliding_window is not None: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -365,23 +255,31 @@ class Cohere2Attention(nn.Module): } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: - logger.warning_once("Setting `attention_type` to `eager` because `output_attentions=True`") - attention_type = "eager" - else: - attention_type = self.config._attn_implementation + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = COHERE2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Cohere2MLP(nn.Module): @@ -416,10 +314,11 @@ class Cohere2DecoderLayer(nn.Module): hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -430,13 +329,13 @@ class Cohere2DecoderLayer(nn.Module): attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence """ @@ -460,7 +359,7 @@ class Cohere2DecoderLayer(nn.Module): hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( + hidden_states_attention, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, @@ -468,6 +367,7 @@ class Cohere2DecoderLayer(nn.Module): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) # Fully Connected @@ -481,9 +381,6 @@ class Cohere2DecoderLayer(nn.Module): if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -653,6 +550,7 @@ class Cohere2Model(Cohere2PreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -727,6 +625,7 @@ class Cohere2Model(Cohere2PreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] @@ -740,16 +639,13 @@ class Cohere2Model(Cohere2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = past_key_values if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() @torch.no_grad() def _update_causal_mask( diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 3e6999b29b..145905287a 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -13,8 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -22,30 +21,29 @@ import torch.utils.checkpoint from ...cache_utils import Cache, HybridCache from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, ) from ...modeling_rope_utils import rope_config_validation +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ...utils import ( - is_flash_attn_2_available, logging, ) from ..cohere.modeling_cohere import ( + CohereAttention, CohereDecoderLayer, CohereForCausalLM, CohereLayerNorm, CoherePreTrainedModel, CohereRotaryEmbedding, apply_rotary_pos_emb, - repeat_kv, + eager_attention_forward, ) from ..gemma2.modeling_gemma2 import Gemma2Model -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) @@ -240,155 +238,31 @@ class Cohere2LayerNorm(CohereLayerNorm): pass -def eager_attention_forward( - config: Cohere2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, torch.Tensor]: - key_states = repeat_kv(key, config.num_key_value_groups) - value_states = repeat_kv(value, config.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(config.head_dim) - - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -def flash_attention_forward( - config: Cohere2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - target_dtype: torch.dtype = torch.float16, - **_kwargs, -) -> Tuple[torch.Tensor, None]: - if mask is not None: - seq_len = mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - value_states = value.transpose(1, 2) - - dropout_rate = config.attention_dropout if config.training else 0.0 - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - mask, - seq_len, - dropout=dropout_rate, - is_causal=config.is_causal, - sliding_window=config.sliding_window, - use_top_left_mask=config._flash_attn_uses_top_left_mask, - ) - - return attn_output, None - - -def sdpa_attention_forward( - config: Cohere2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, None]: - key = repeat_kv(key, config.num_key_value_groups) - value = repeat_kv(value, config.num_key_value_groups) - - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query.device.type == "cuda" and causal_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and query.shape[1] > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=config.attention_dropout if config.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, None - - -COHERE2_ATTENTION_FUNCTION = { - "flash_attention_2": flash_attention_forward, - "eager": eager_attention_forward, - "sdpa": sdpa_attention_forward, -} - - -class Cohere2Attention(nn.Module): +class Cohere2Attention(CohereAttention, nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None): - super().__init__() + nn.Module.__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta self.is_causal = True - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) self.sliding_window = ( config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None ) @@ -397,25 +271,19 @@ class Cohere2Attention(nn.Module): self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings - if self.sliding_window is not None: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -428,23 +296,31 @@ class Cohere2Attention(nn.Module): } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: - logger.warning_once("Setting `attention_type` to `eager` because `output_attentions=True`") - attention_type = "eager" - else: - attention_type = self.config._attn_implementation + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = COHERE2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Cohere2DecoderLayer(CohereDecoderLayer): @@ -460,10 +336,11 @@ class Cohere2DecoderLayer(CohereDecoderLayer): hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -474,13 +351,13 @@ class Cohere2DecoderLayer(CohereDecoderLayer): attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence """ @@ -504,7 +381,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer): hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( + hidden_states_attention, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, @@ -512,6 +389,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) # Fully Connected @@ -525,9 +403,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer): if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -559,6 +434,7 @@ class Cohere2Model(Gemma2Model): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -633,6 +509,7 @@ class Cohere2Model(Gemma2Model): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] @@ -646,16 +523,13 @@ class Cohere2Model(Gemma2Model): if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = past_key_values if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() class Cohere2ForCausalLM(CohereForCausalLM): diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index a1f6897661..8b995d1a08 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -548,6 +548,7 @@ class Gemma2Model(Gemma2PreTrainedModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -633,6 +634,7 @@ class Gemma2Model(Gemma2PreTrainedModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 48b1241136..f73b9ea840 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -378,6 +378,7 @@ class Gemma2Model(GemmaModel): output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -463,6 +464,7 @@ class Gemma2Model(GemmaModel): output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index 8e1a4834d1..144846772f 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -201,7 +201,6 @@ class Cohere2IntegrationTest(unittest.TestCase): cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] @require_read_token - @unittest.skip("Cohere2 has not been released yet") def test_model_bf16(self): model_id = "CohereForAI/command-r7b-12-2024" EXPECTED_TEXTS = [ @@ -222,7 +221,6 @@ class Cohere2IntegrationTest(unittest.TestCase): self.assertEqual(output_text, EXPECTED_TEXTS) @require_read_token - @unittest.skip("Cohere2 has not been released yet") def test_model_fp16(self): model_id = "CohereForAI/command-r7b-12-2024" EXPECTED_TEXTS = [ @@ -243,7 +241,6 @@ class Cohere2IntegrationTest(unittest.TestCase): self.assertEqual(output_text, EXPECTED_TEXTS) @require_read_token - @unittest.skip("Cohere2 has not been released yet") def test_model_pipeline_bf16(self): # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Cohere2 before this PR model_id = "CohereForAI/command-r7b-12-2024" @@ -269,7 +266,6 @@ class Cohere2IntegrationTest(unittest.TestCase): @require_torch_gpu @mark.flash_attn_test @slow - @unittest.skip("Cohere2 has not been released yet") def test_model_flash_attn(self): # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for Gemma2, especially in long context model_id = "CohereForAI/command-r7b-12-2024" @@ -291,7 +287,6 @@ class Cohere2IntegrationTest(unittest.TestCase): @slow @require_read_token - @unittest.skip("Cohere2 has not been released yet") def test_export_static_cache(self): if version.parse(torch.__version__) < version.parse("2.5.0"): self.skipTest(reason="This test requires torch >= 2.5 to run.")