Refactor/fix Cohere2 (#35594)

* refactor/fix cohere2

* add kwargs

* tests

* remove func and import it
This commit is contained in:
Cyril Vallez
2025-01-09 17:54:57 +01:00
committed by GitHub
parent 32e0db8a69
commit 3a4ae6eace
5 changed files with 146 additions and 377 deletions

View File

@@ -19,8 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -31,23 +30,18 @@ from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import ( from ...utils import (
LossKwargs, LossKwargs,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from .configuration_cohere2 import Cohere2Config from .configuration_cohere2 import Cohere2Config
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Cohere2Config" _CONFIG_FOR_DOC = "Cohere2Config"
@@ -139,6 +133,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) attention: scaled matmul, additive mask, softmax, dropout.

    Args:
        module: The calling attention module. Only `num_key_value_groups` and
            `training` are read from it.
        query: Query states; multiplied against the transposed (axes 2 and 3) keys.
        key: Key states, expanded with `repeat_kv` before use.
        value: Value states, expanded with `repeat_kv` before use.
        attention_mask: Optional additive mask. Its last axis is trimmed to the
            number of key positions before being added to the scores.
        scaling: Factor the raw scores are multiplied by.
        dropout: Dropout probability applied to the attention probabilities.
        **kwargs: Accepted and ignored so every attention backend can share one
            call signature.

    Returns:
        A tuple `(attn_output, attn_weights)`. `attn_output` has axes 1 and 2
        swapped back and is contiguous; `attn_weights` are the probabilities
        after dropout.
    """
    group_count = module.num_key_value_groups
    expanded_keys = repeat_kv(key, group_count)
    expanded_values = repeat_kv(value, group_count)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask may be longer than the keys; keep only the matching key positions.
        kv_len = expanded_keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    # Softmax runs in float32 and is cast back to the query dtype afterwards.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
def rotate_half(x): def rotate_half(x):
# Split and rotate. Note that this function is different from e.g. Llama. # Split and rotate. Note that this function is different from e.g. Llama.
x1 = x[..., ::2] x1 = x[..., ::2]
@@ -177,120 +197,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
def eager_attention_forward(
    config: Cohere2Config,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor],
    **_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference (non-fused) attention driven by attributes read from `config`.

    Args:
        config: Object providing `num_key_value_groups`, `head_dim`,
            `attention_dropout` and `training`. The caller visible in this file
            passes the attention module itself (`self`) in this position.
        query: Query states; multiplied against the transposed (axes 2 and 3) keys.
        key: Key states, expanded with `repeat_kv` before use.
        value: Value states, expanded with `repeat_kv` before use.
        mask: Optional additive mask; its last axis is sliced to the key length.
        **_kwargs: Ignored; present so all backends share a call signature.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has axes 1 and 2
        swapped and is contiguous, and `attn_weights` are the post-dropout
        probabilities.
    """
    n_groups = config.num_key_value_groups
    keys_full = repeat_kv(key, n_groups)
    values_full = repeat_kv(value, n_groups)

    scores = torch.matmul(query, keys_full.transpose(2, 3)) / math.sqrt(config.head_dim)

    if mask is not None:
        # Whatever the mask length is, keep only the columns that match the keys.
        scores = scores + mask[:, :, :, : keys_full.shape[-2]]

    # Softmax is computed in fp32, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=config.attention_dropout, training=config.training)

    mixed = torch.matmul(probs, values_full)
    return mixed.transpose(1, 2).contiguous(), probs
def flash_attention_forward(
    config: Cohere2Config,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor],
    target_dtype: torch.dtype = torch.float16,
    **_kwargs,
) -> Tuple[torch.Tensor, None]:
    """Run attention through the `_flash_attention_forward` wrapper.

    Args:
        config: Object providing `attention_dropout`, `training`, `is_causal`,
            `sliding_window` and `_flash_attn_uses_top_left_mask`. The caller
            visible in this file passes the attention module itself here.
        query: Query states; axes 1 and 2 are swapped before the kernel call.
        key: Key states; axes 1 and 2 are swapped before the kernel call.
        value: Value states; axes 1 and 2 are swapped before the kernel call.
        mask: Optional mask. When given, `mask.shape[1]` is used as the sequence
            length and `query`/`value` are truncated to it along axis 2.
        target_dtype: Dtype that float32 inputs are cast to before the kernel call.
        **_kwargs: Ignored; present so all backends share a call signature.

    Returns:
        `(attn_output, None)`; this backend never returns attention weights.
    """
    # Bind the length up front: previously it was only assigned inside the
    # `mask is not None` branch, so a call without a mask raised
    # UnboundLocalError when `seq_len` was handed to the kernel wrapper.
    seq_len = query.shape[2]
    if mask is not None:
        seq_len = mask.shape[1]
        query = query[:, :, :seq_len]
        # NOTE(review): `key` is not truncated here although `value` is - confirm this is intended.
        value = value[:, :, :seq_len]

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout
    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
    query_states = query.transpose(1, 2)
    key_states = key.transpose(1, 2)
    value_states = value.transpose(1, 2)

    # Dropout is only active while training.
    dropout_rate = config.attention_dropout if config.training else 0.0

    # float32 inputs are cast down to `target_dtype` before calling the kernel wrapper.
    if query_states.dtype == torch.float32:
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        mask,
        seq_len,
        dropout=dropout_rate,
        is_causal=config.is_causal,
        sliding_window=config.sliding_window,
        use_top_left_mask=config._flash_attn_uses_top_left_mask,
    )

    return attn_output, None
def sdpa_attention_forward(
    config: Cohere2Config,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor],
    **_kwargs,
) -> Tuple[torch.Tensor, None]:
    """Run attention through `torch.nn.functional.scaled_dot_product_attention`.

    Args:
        config: Object providing `num_key_value_groups`, `attention_dropout` and
            `training`. The caller visible in this file passes the attention
            module itself here.
        query: Query states. The caller in this file builds them with
            `view(bsz, q_len, num_heads, head_dim).transpose(1, 2)`, so axis 1 is
            heads and axis 2 is the query length.
        key: Key states, expanded with `repeat_kv` before use.
        value: Value states, expanded with `repeat_kv` before use.
        mask: Optional additive mask; its last axis is sliced to the key length.
        **_kwargs: Ignored; present so all backends share a call signature.

    Returns:
        `(attn_output, None)`, with axes 1 and 2 of the SDPA result swapped and
        made contiguous. This backend never returns attention weights.
    """
    key = repeat_kv(key, config.num_key_value_groups)
    value = repeat_kv(value, config.num_key_value_groups)

    causal_mask = mask
    if mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query.device.type == "cuda" and causal_mask is not None:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` statement instead of an inline
    # conditional inside the SDPA call, to support both torch.compile's dynamic shapes and full graph options.
    # Fix: the query length lives on axis 2 (axis 1 is the head axis), so the previous `query.shape[1] > 1` test
    # enabled causal masking for single-token queries whenever there was more than one head.
    is_causal = causal_mask is None and query.shape[2] > 1

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=config.attention_dropout if config.training else 0.0,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
# Dispatch table from attention implementation name to the function that runs it.
# The attention module's forward indexes it with `config._attn_implementation`,
# or with "eager" when attention weights are requested.
COHERE2_ATTENTION_FUNCTION = {
"flash_attention_2": flash_attention_forward,
"eager": eager_attention_forward,
"sdpa": sdpa_attention_forward,
}
class Cohere2Attention(nn.Module): class Cohere2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -298,34 +204,24 @@ class Cohere2Attention(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.layer_idx = layer_idx self.layer_idx = layer_idx
if layer_idx is None: self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
logger.warning_once( self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " self.scaling = self.head_dim**-0.5
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size: self.q_proj = nn.Linear(
raise ValueError( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" )
f" and `num_heads`: {self.num_heads})." self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
) )
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self.sliding_window = ( self.sliding_window = (
config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None
) )
@@ -334,25 +230,19 @@ class Cohere2Attention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings cos, sin = position_embeddings
if self.sliding_window is not None: if self.sliding_window is not None:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -365,23 +255,31 @@ class Cohere2Attention(nn.Module):
} }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: attention_interface: Callable = eager_attention_forward
logger.warning_once("Setting `attention_type` to `eager` because `output_attentions=True`") if self.config._attn_implementation != "eager":
attention_type = "eager" if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else: else:
attention_type = self.config._attn_implementation attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = COHERE2_ATTENTION_FUNCTION[attention_type]( attn_output, attn_weights = attention_interface(
self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window,
**kwargs,
) )
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Cohere2MLP(nn.Module): class Cohere2MLP(nn.Module):
@@ -416,10 +314,11 @@ class Cohere2DecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -430,13 +329,13 @@ class Cohere2DecoderLayer(nn.Module):
attention_mask (`torch.FloatTensor`, *optional*): attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used. query_sequence_length, key_sequence_length)` if default attention is used.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*): output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. returned tensors for more detail.
use_cache (`bool`, *optional*): use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`). (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence Indices depicting the position of the input sequence tokens in the sequence
""" """
@@ -460,7 +359,7 @@ class Cohere2DecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( hidden_states_attention, self_attn_weights = self.self_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
attention_mask=attention_mask, attention_mask=attention_mask,
@@ -468,6 +367,7 @@ class Cohere2DecoderLayer(nn.Module):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
**kwargs,
) )
# Fully Connected # Fully Connected
@@ -481,9 +381,6 @@ class Cohere2DecoderLayer(nn.Module):
if output_attentions: if output_attentions:
outputs += (self_attn_weights,) outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs return outputs
@@ -653,6 +550,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -727,6 +625,7 @@ class Cohere2Model(Cohere2PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
**flash_attn_kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
@@ -740,16 +639,13 @@ class Cohere2Model(Cohere2PreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = past_key_values if use_cache else None output = BaseModelOutputWithPast(
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_cache, past_key_values=past_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
) )
return output if return_dict else output.to_tuple()
@torch.no_grad() @torch.no_grad()
def _update_causal_mask( def _update_causal_mask(

View File

@@ -13,8 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math from typing import Callable, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -22,30 +21,29 @@ import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache from ...cache_utils import Cache, HybridCache
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
) )
from ...modeling_rope_utils import rope_config_validation from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import ( from ...utils import (
is_flash_attn_2_available,
logging, logging,
) )
from ..cohere.modeling_cohere import ( from ..cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer, CohereDecoderLayer,
CohereForCausalLM, CohereForCausalLM,
CohereLayerNorm, CohereLayerNorm,
CoherePreTrainedModel, CoherePreTrainedModel,
CohereRotaryEmbedding, CohereRotaryEmbedding,
apply_rotary_pos_emb, apply_rotary_pos_emb,
repeat_kv, eager_attention_forward,
) )
from ..gemma2.modeling_gemma2 import Gemma2Model from ..gemma2.modeling_gemma2 import Gemma2Model
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -240,155 +238,31 @@ class Cohere2LayerNorm(CohereLayerNorm):
pass pass
def eager_attention_forward( class Cohere2Attention(CohereAttention, nn.Module):
config: Cohere2Config,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor],
**_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
key_states = repeat_kv(key, config.num_key_value_groups)
value_states = repeat_kv(value, config.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(config.head_dim)
if mask is not None: # no matter the length, we just slice it
causal_mask = mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def flash_attention_forward(
    config: Cohere2Config,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor],
    target_dtype: torch.dtype = torch.float16,
    **_kwargs,
) -> Tuple[torch.Tensor, None]:
    """Run attention through the `_flash_attention_forward` wrapper.

    Args:
        config: Object providing `attention_dropout`, `training`, `is_causal`,
            `sliding_window` and `_flash_attn_uses_top_left_mask`. The caller
            visible in this file passes the attention module itself here.
        query: Query states; axes 1 and 2 are swapped before the kernel call.
        key: Key states; axes 1 and 2 are swapped before the kernel call.
        value: Value states; axes 1 and 2 are swapped before the kernel call.
        mask: Optional mask. When given, `mask.shape[1]` is used as the sequence
            length and `query`/`value` are truncated to it along axis 2.
        target_dtype: Dtype that float32 inputs are cast to before the kernel call.
        **_kwargs: Ignored; present so all backends share a call signature.

    Returns:
        `(attn_output, None)`; this backend never returns attention weights.
    """
    # Bind the length up front: previously it was only assigned inside the
    # `mask is not None` branch, so a call without a mask raised
    # UnboundLocalError when `seq_len` was handed to the kernel wrapper.
    seq_len = query.shape[2]
    if mask is not None:
        seq_len = mask.shape[1]
        query = query[:, :, :seq_len]
        # NOTE(review): `key` is not truncated here although `value` is - confirm this is intended.
        value = value[:, :, :seq_len]

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout
    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
    query_states = query.transpose(1, 2)
    key_states = key.transpose(1, 2)
    value_states = value.transpose(1, 2)

    # Dropout is only active while training.
    dropout_rate = config.attention_dropout if config.training else 0.0

    # float32 inputs are cast down to `target_dtype` before calling the kernel wrapper.
    if query_states.dtype == torch.float32:
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        mask,
        seq_len,
        dropout=dropout_rate,
        is_causal=config.is_causal,
        sliding_window=config.sliding_window,
        use_top_left_mask=config._flash_attn_uses_top_left_mask,
    )

    return attn_output, None
def sdpa_attention_forward(
    config: Cohere2Config,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor],
    **_kwargs,
) -> Tuple[torch.Tensor, None]:
    """Run attention through `torch.nn.functional.scaled_dot_product_attention`.

    Args:
        config: Object providing `num_key_value_groups`, `attention_dropout` and
            `training`. The caller visible in this file passes the attention
            module itself here.
        query: Query states. The caller in this file builds them with
            `view(bsz, q_len, num_heads, head_dim).transpose(1, 2)`, so axis 1 is
            heads and axis 2 is the query length.
        key: Key states, expanded with `repeat_kv` before use.
        value: Value states, expanded with `repeat_kv` before use.
        mask: Optional additive mask; its last axis is sliced to the key length.
        **_kwargs: Ignored; present so all backends share a call signature.

    Returns:
        `(attn_output, None)`, with axes 1 and 2 of the SDPA result swapped and
        made contiguous. This backend never returns attention weights.
    """
    key = repeat_kv(key, config.num_key_value_groups)
    value = repeat_kv(value, config.num_key_value_groups)

    causal_mask = mask
    if mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    if query.device.type == "cuda" and causal_mask is not None:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` statement instead of an inline
    # conditional inside the SDPA call, to support both torch.compile's dynamic shapes and full graph options.
    # Fix: the query length lives on axis 2 (axis 1 is the head axis), so the previous `query.shape[1] > 1` test
    # enabled causal masking for single-token queries whenever there was more than one head.
    is_causal = causal_mask is None and query.shape[2] > 1

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=config.attention_dropout if config.training else 0.0,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
# Dispatch table from attention implementation name to the function that runs it.
# The attention module's forward indexes it with `config._attn_implementation`,
# or with "eager" when attention weights are requested.
COHERE2_ATTENTION_FUNCTION = {
"flash_attention_2": flash_attention_forward,
"eager": eager_attention_forward,
"sdpa": sdpa_attention_forward,
}
class Cohere2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper""" """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None): def __init__(self, config: Cohere2Config, layer_idx: Optional[int] = None):
super().__init__() nn.Module.__init__()
self.config = config self.config = config
self.layer_idx = layer_idx self.layer_idx = layer_idx
if layer_idx is None: self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
logger.warning_once( self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " self.scaling = self.head_dim**-0.5
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size: self.q_proj = nn.Linear(
raise ValueError( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" )
f" and `num_heads`: {self.num_heads})." self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
) )
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
self.sliding_window = ( self.sliding_window = (
config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None config.sliding_window if (self.layer_idx + 1) % self.config.sliding_window_pattern != 0 else None
) )
@@ -397,25 +271,19 @@ class Cohere2Attention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings cos, sin = position_embeddings
if self.sliding_window is not None: if self.sliding_window is not None:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -428,23 +296,31 @@ class Cohere2Attention(nn.Module):
} }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: attention_interface: Callable = eager_attention_forward
logger.warning_once("Setting `attention_type` to `eager` because `output_attentions=True`") if self.config._attn_implementation != "eager":
attention_type = "eager" if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else: else:
attention_type = self.config._attn_implementation attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = COHERE2_ATTENTION_FUNCTION[attention_type]( attn_output, attn_weights = attention_interface(
self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window,
**kwargs,
) )
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Cohere2DecoderLayer(CohereDecoderLayer): class Cohere2DecoderLayer(CohereDecoderLayer):
@@ -460,10 +336,11 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -474,13 +351,13 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
attention_mask (`torch.FloatTensor`, *optional*): attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used. query_sequence_length, key_sequence_length)` if default attention is used.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*): output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. returned tensors for more detail.
use_cache (`bool`, *optional*): use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`). (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence Indices depicting the position of the input sequence tokens in the sequence
""" """
@@ -504,7 +381,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention
hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( hidden_states_attention, self_attn_weights = self.self_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
attention_mask=attention_mask, attention_mask=attention_mask,
@@ -512,6 +389,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
**kwargs,
) )
# Fully Connected # Fully Connected
@@ -525,9 +403,6 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
if output_attentions: if output_attentions:
outputs += (self_attn_weights,) outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs return outputs
@@ -559,6 +434,7 @@ class Cohere2Model(Gemma2Model):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -633,6 +509,7 @@ class Cohere2Model(Gemma2Model):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
**flash_attn_kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
@@ -646,16 +523,13 @@ class Cohere2Model(Gemma2Model):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = past_key_values if use_cache else None output = BaseModelOutputWithPast(
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=next_cache, past_key_values=past_key_values,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
) )
return output if return_dict else output.to_tuple()
class Cohere2ForCausalLM(CohereForCausalLM): class Cohere2ForCausalLM(CohereForCausalLM):

View File

@@ -548,6 +548,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -633,6 +634,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
**flash_attn_kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -378,6 +378,7 @@ class Gemma2Model(GemmaModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -463,6 +464,7 @@ class Gemma2Model(GemmaModel):
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
**flash_attn_kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -201,7 +201,6 @@ class Cohere2IntegrationTest(unittest.TestCase):
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@require_read_token @require_read_token
@unittest.skip("Cohere2 has not been released yet")
def test_model_bf16(self): def test_model_bf16(self):
model_id = "CohereForAI/command-r7b-12-2024" model_id = "CohereForAI/command-r7b-12-2024"
EXPECTED_TEXTS = [ EXPECTED_TEXTS = [
@@ -222,7 +221,6 @@ class Cohere2IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS) self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token @require_read_token
@unittest.skip("Cohere2 has not been released yet")
def test_model_fp16(self): def test_model_fp16(self):
model_id = "CohereForAI/command-r7b-12-2024" model_id = "CohereForAI/command-r7b-12-2024"
EXPECTED_TEXTS = [ EXPECTED_TEXTS = [
@@ -243,7 +241,6 @@ class Cohere2IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS) self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token @require_read_token
@unittest.skip("Cohere2 has not been released yet")
def test_model_pipeline_bf16(self): def test_model_pipeline_bf16(self):
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Cohere2 before this PR # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Cohere2 before this PR
model_id = "CohereForAI/command-r7b-12-2024" model_id = "CohereForAI/command-r7b-12-2024"
@@ -269,7 +266,6 @@ class Cohere2IntegrationTest(unittest.TestCase):
@require_torch_gpu @require_torch_gpu
@mark.flash_attn_test @mark.flash_attn_test
@slow @slow
@unittest.skip("Cohere2 has not been released yet")
def test_model_flash_attn(self): def test_model_flash_attn(self):
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for Gemma2, especially in long context # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for Gemma2, especially in long context
model_id = "CohereForAI/command-r7b-12-2024" model_id = "CohereForAI/command-r7b-12-2024"
@@ -291,7 +287,6 @@ class Cohere2IntegrationTest(unittest.TestCase):
@slow @slow
@require_read_token @require_read_token
@unittest.skip("Cohere2 has not been released yet")
def test_export_static_cache(self): def test_export_static_cache(self):
if version.parse(torch.__version__) < version.parse("2.5.0"): if version.parse(torch.__version__) < version.parse("2.5.0"):
self.skipTest(reason="This test requires torch >= 2.5 to run.") self.skipTest(reason="This test requires torch >= 2.5 to run.")